mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 18:23:45 +08:00
feat: auto translate and rewrite prompt for midjourney and stable-diffusion
This commit is contained in:
@@ -66,7 +66,6 @@ type MidJourneyPlusConfig struct {
|
||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
||||
ApiURL string // api 地址
|
||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||
CdnURL string // CDN 加速地址
|
||||
ApiKey string
|
||||
NotifyURL string // 任务进度更新回调地址
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ type SdTask struct {
|
||||
SessionId string `json:"session_id"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
|
||||
type PromptHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler {
|
||||
return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// Rewrite translate and rewrite prompt with ChatGPT
|
||||
func (h *PromptHandler) Rewrite(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt))
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
func (h *PromptHandler) Translate(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt))
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
@@ -133,6 +133,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdScaleAlg: data.HdScaleAlg,
|
||||
HdSteps: data.HdSteps,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
UserId: userId,
|
||||
Type: types.TaskImage.String(),
|
||||
@@ -153,7 +154,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskImage,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
@@ -371,13 +371,6 @@ func main() {
|
||||
group.GET("hits", h.Hits)
|
||||
}),
|
||||
|
||||
fx.Provide(handler.NewPromptHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||
group := s.Engine.Group("/api/prompt/")
|
||||
group.POST("rewrite", h.Rewrite)
|
||||
group.POST("translate", h.Translate)
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewFunctionHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||
group := s.Engine.Group("/api/admin/function/")
|
||||
|
||||
@@ -22,16 +22,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
|
||||
var apiURL string
|
||||
if config.CdnURL != "" {
|
||||
apiURL = config.CdnURL
|
||||
} else {
|
||||
apiURL = config.ApiURL
|
||||
}
|
||||
if config.Mode == "" {
|
||||
config.Mode = "fast"
|
||||
}
|
||||
return &Client{Config: config, apiURL: apiURL}
|
||||
return &Client{Config: config, apiURL: config.ApiURL}
|
||||
}
|
||||
|
||||
type ImageReq struct {
|
||||
@@ -81,6 +72,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
}
|
||||
|
||||
}
|
||||
logger.Info("API URL: ", apiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
@@ -90,9 +82,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
@@ -132,8 +122,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
@@ -183,8 +172,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
|
||||
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
|
||||
@@ -167,11 +167,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
job.Prompt = task.PromptEn
|
||||
if task.ImageUrl != "" {
|
||||
if s.Client.Config.CdnURL != "" {
|
||||
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
|
||||
} else {
|
||||
job.OrgURL = task.ImageUrl
|
||||
}
|
||||
job.OrgURL = task.ImageUrl
|
||||
}
|
||||
job.MessageId = task.Id
|
||||
tx := s.db.Updates(&job)
|
||||
|
||||
@@ -2,8 +2,11 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -62,6 +65,14 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 翻译提示词
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
|
||||
@@ -2,6 +2,7 @@ package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
@@ -46,6 +47,14 @@ func (s *Service) Run() {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
// 翻译提示词
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
||||
err = s.Txt2Img(task)
|
||||
if err != nil {
|
||||
@@ -66,7 +75,7 @@ func (s *Service) Run() {
|
||||
type Txt2ImgReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt"`
|
||||
Seed int64 `json:"seed"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Steps int `json:"steps"`
|
||||
CfgScale float32 `json:"cfg_scale"`
|
||||
Width int `json:"width"`
|
||||
|
||||
4
api/service/types.go
Normal file
4
api/service/types.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package service
|
||||
|
||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
@@ -3,15 +3,10 @@ package utils
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -66,64 +61,3 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||
|
||||
return imageBytes, nil
|
||||
}
|
||||
|
||||
type apiRes struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type apiErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
var errRes apiErrRes
|
||||
client := req.C()
|
||||
if apiKey.ProxyURL != "" {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo-0125",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: messages,
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(apiKey.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CalcTokens(text string, model string) (int, error) {
|
||||
@@ -18,3 +23,64 @@ func CalcTokens(text string, model string) (int, error) {
|
||||
token := tke.Encode(text, nil, nil)
|
||||
return len(token), nil
|
||||
}
|
||||
|
||||
type apiRes struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type apiErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
var errRes apiErrRes
|
||||
client := req.C()
|
||||
if apiKey.ProxyURL != "" {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo-0125",
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: messages,
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(apiKey.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
@@ -94,6 +95,7 @@ func InterfaceToString(value interface{}) string {
|
||||
return JsonEncode(value)
|
||||
}
|
||||
|
||||
// CutWords 截取前 N 个单词
|
||||
func CutWords(str string, num int) string {
|
||||
// 按空格分割字符串为单词切片
|
||||
words := strings.Fields(str)
|
||||
@@ -105,3 +107,13 @@ func CutWords(str string, num int) string {
|
||||
return str
|
||||
}
|
||||
}
|
||||
|
||||
// HasChinese 判断文本是否含有中文
|
||||
func HasChinese(text string) bool {
|
||||
for _, char := range text {
|
||||
if unicode.Is(unicode.Scripts["Han"], char) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user