From b5947545cb4f233e0d6f14cc49d1c7a34228bf9f Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 27 Mar 2024 13:45:52 +0800 Subject: [PATCH] feat: auto translate and rewrite prompt for midjourney and stable-diffusion --- api/core/types/config.go | 1 - api/core/types/task.go | 1 - api/handler/prompt_handler.go | 60 ---------------- api/handler/sd_handler.go | 2 +- api/main.go | 7 -- api/service/mj/plus/client.go | 22 ++---- api/service/mj/plus/service.go | 6 +- api/service/mj/service.go | 11 +++ api/service/sd/service.go | 11 ++- api/service/types.go | 4 ++ api/utils/net.go | 66 ----------------- api/utils/openai.go | 66 +++++++++++++++++ api/utils/strings.go | 12 ++++ web/src/views/ChatPlus.vue | 9 ++- web/src/views/ImageMj.vue | 84 +++------------------- web/src/views/ImageSd.vue | 103 +++++---------------------- web/src/views/mobile/ChatSession.vue | 9 ++- web/src/views/mobile/ImageMj.vue | 43 ++++------- 18 files changed, 162 insertions(+), 355 deletions(-) delete mode 100644 api/handler/prompt_handler.go create mode 100644 api/service/types.go diff --git a/api/core/types/config.go b/api/core/types/config.go index 39990fb5..612d7ddc 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -66,7 +66,6 @@ type MidJourneyPlusConfig struct { Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 ApiURL string // api 地址 Mode string // 绘画模式,可选值:fast/turbo/relax - CdnURL string // CDN 加速地址 ApiKey string NotifyURL string // 任务进度更新回调地址 } diff --git a/api/core/types/task.go b/api/core/types/task.go index 7e84aa65..0120c739 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -36,7 +36,6 @@ type SdTask struct { SessionId string `json:"session_id"` Type TaskType `json:"type"` UserId int `json:"user_id"` - Prompt string `json:"prompt,omitempty"` Params SdTaskParams `json:"params"` RetryCount int `json:"retry_count"` } diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go deleted file mode 100644 index 7e82ff19..00000000 --- a/api/handler/prompt_handler.go +++ /dev/null @@ -1,60 +0,0 @@ -package handler - -import ( - "chatplus/core" - "chatplus/core/types" - "chatplus/utils" - "chatplus/utils/resp" - "fmt" - - "github.com/gin-gonic/gin" - "gorm.io/gorm" -) - -const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" -const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" - -type PromptHandler struct { - BaseHandler -} - -func NewPromptHandler(app *core.AppServer, db *gorm.DB) *PromptHandler { - return &PromptHandler{BaseHandler: BaseHandler{App: app, DB: db}} -} - -// Rewrite translate and rewrite prompt with ChatGPT -func (h *PromptHandler) Rewrite(c *gin.Context) { - var data struct { - Prompt string `json:"prompt"` - } - if err := c.ShouldBindJSON(&data); err != nil { - resp.ERROR(c, types.InvalidArgs) - return - } - - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(rewritePromptTemplate, data.Prompt)) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - - resp.SUCCESS(c, content) -} - -func (h *PromptHandler) Translate(c *gin.Context) { - var data struct { - Prompt string `json:"prompt"` - } - if err := c.ShouldBindJSON(&data); err != nil { - resp.ERROR(c, types.InvalidArgs) - return - } - - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, data.Prompt)) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - - resp.SUCCESS(c, content) -} diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 0f0320d4..9881a35d 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -133,6 +133,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { HdScaleAlg: data.HdScaleAlg, HdSteps: data.HdSteps, } + job := model.SdJob{ UserId: userId, Type: types.TaskImage.String(), @@ -153,7 +154,6 @@ func (h *SdJobHandler) Image(c *gin.Context) { Id: int(job.Id), SessionId: data.SessionId, Type: types.TaskImage, - Prompt: data.Prompt, Params: params, UserId: userId, }) diff --git a/api/main.go b/api/main.go index 233f38e9..d070f341 100644 --- a/api/main.go +++ b/api/main.go @@ -371,13 +371,6 @@ func main() { group.GET("hits", h.Hits) }), - fx.Provide(handler.NewPromptHandler), - fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { - group := s.Engine.Group("/api/prompt/") - group.POST("rewrite", h.Rewrite) - group.POST("translate", h.Translate) - }), - fx.Provide(admin.NewFunctionHandler), fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { group := s.Engine.Group("/api/admin/function/") diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go index fe525345..757ebc96 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus/client.go @@ -22,16 +22,7 @@ type Client struct { } func NewClient(config types.MidJourneyPlusConfig) *Client { - var apiURL string - if config.CdnURL != "" { - apiURL = config.CdnURL - } else { - apiURL = config.ApiURL - } - if config.Mode == "" { - config.Mode = "fast" - } - return &Client{Config: config, apiURL: apiURL} + return &Client{Config: config, apiURL: config.ApiURL} } type ImageReq struct { @@ -81,6 +72,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { } } + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := req.C().R(). @@ -90,9 +82,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { SetErrorResult(&errRes). Post(apiURL) if err != nil { - errStr, _ := io.ReadAll(r.Body) - logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL) - return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) } if r.IsErrorState() { @@ -132,8 +122,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) { SetErrorResult(&errRes). Post(apiURL) if err != nil { - errStr, _ := io.ReadAll(r.Body) - return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) } if r.IsErrorState() { @@ -183,8 +172,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { SetErrorResult(&errRes). Post(apiURL) if err != nil { - errStr, _ := io.ReadAll(r.Body) - return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) } if r.IsErrorState() { diff --git a/api/service/mj/plus/service.go b/api/service/mj/plus/service.go index f03fd679..f02db6cb 100644 --- a/api/service/mj/plus/service.go +++ b/api/service/mj/plus/service.go @@ -167,11 +167,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error { job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) job.Prompt = task.PromptEn if task.ImageUrl != "" { - if s.Client.Config.CdnURL != "" { - job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1) - } else { - job.OrgURL = task.ImageUrl - } + job.OrgURL = task.ImageUrl } job.MessageId = task.Id tx := s.db.Updates(&job) diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 0cc686e0..154f70db 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -2,8 +2,11 @@ package mj import ( "chatplus/core/types" + "chatplus/service" "chatplus/store" "chatplus/store/model" + "chatplus/utils" + "fmt" "strings" "sync/atomic" "time" @@ -62,6 +65,14 @@ func (s *Service) Run() { continue } + // 翻译提示词 + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt)) + if err == nil { + task.Prompt = content + } + } + logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) switch task.Type { case types.TaskImage: diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 6cae7b0a..ccb56974 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -2,6 +2,7 @@ package sd import ( "chatplus/core/types" + "chatplus/service" "chatplus/service/oss" "chatplus/store" "chatplus/store/model" @@ -46,6 +47,14 @@ func (s *Service) Run() { logger.Errorf("taking task with error: %v", err) continue } + // 翻译提示词 + if utils.HasChinese(task.Params.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt)) + if err == nil { + task.Params.Prompt = content + } + } + logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) err = s.Txt2Img(task) if err != nil { @@ -66,7 +75,7 @@ func (s *Service) Run() { type Txt2ImgReq struct { Prompt string `json:"prompt"` NegativePrompt string `json:"negative_prompt"` - Seed int64 `json:"seed"` + Seed int64 `json:"seed,omitempty"` Steps int `json:"steps"` CfgScale float32 `json:"cfg_scale"` Width int `json:"width"` diff --git a/api/service/types.go b/api/service/types.go new file mode 100644 index 00000000..9a8a0d00 --- /dev/null +++ b/api/service/types.go @@ -0,0 +1,4 @@ +package service + +const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" +const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/api/utils/net.go b/api/utils/net.go index 578d42bf..74d05e03 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -3,15 +3,10 @@ package utils import ( "chatplus/core/types" logger2 "chatplus/logger" - "chatplus/store/model" "encoding/json" - "fmt" - "github.com/imroc/req/v3" - "gorm.io/gorm" "io" "net/http" "net/url" - "time" ) var logger = logger2.GetLogger() @@ -66,64 +61,3 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) { return imageBytes, nil } - -type apiRes struct { - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - -type apiErrRes struct { - Error struct { - Code interface{} `json:"code"` - Message string `json:"message"` - Param interface{} `json:"param"` - Type string `json:"type"` - } `json:"error"` -} - -func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { - var apiKey model.ApiKey - res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey) - if res.Error != nil { - return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) - } - - messages := make([]interface{}, 1) - messages[0] = types.Message{ - Role: "user", - Content: prompt, - } - - var response apiRes - var errRes apiErrRes - client := req.C() - if apiKey.ProxyURL != "" { - client.SetProxyURL(apiKey.ApiURL) - } - r, err := client.R().SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(types.ApiRequest{ - Model: "gpt-3.5-turbo-0125", - Temperature: 0.9, - MaxTokens: 1024, - Stream: false, - Messages: messages, - }). - SetErrorResult(&errRes). - SetSuccessResult(&response).Post(apiKey.ApiURL) - if err != nil || r.IsErrorState() { - return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) - } - - // 更新 API KEY 的最后使用时间 - db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - - return response.Choices[0].Message.Content, nil -} diff --git a/api/utils/openai.go b/api/utils/openai.go index 661c31ff..9842e07e 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -1,8 +1,13 @@ package utils import ( + "chatplus/core/types" + "chatplus/store/model" "fmt" + "github.com/imroc/req/v3" "github.com/pkoukk/tiktoken-go" + "gorm.io/gorm" + "time" ) func CalcTokens(text string, model string) (int, error) { @@ -18,3 +23,64 @@ func CalcTokens(text string, model string) (int, error) { token := tke.Encode(text, nil, nil) return len(token), nil } + +type apiRes struct { + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +type apiErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { + var apiKey model.ApiKey + res := db.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).First(&apiKey) + if res.Error != nil { + return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) + } + + messages := make([]interface{}, 1) + messages[0] = types.Message{ + Role: "user", + Content: prompt, + } + + var response apiRes + var errRes apiErrRes + client := req.C() + if apiKey.ProxyURL != "" { + client.SetProxyURL(apiKey.ApiURL) + } + r, err := client.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(types.ApiRequest{ + Model: "gpt-3.5-turbo-0125", + Temperature: 0.9, + MaxTokens: 1024, + Stream: false, + Messages: messages, + }). + SetErrorResult(&errRes). + SetSuccessResult(&response).Post(apiKey.ApiURL) + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) + } + + // 更新 API KEY 的最后使用时间 + db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + return response.Choices[0].Message.Content, nil +} diff --git a/api/utils/strings.go b/api/utils/strings.go index 6268ffee..6a9e7655 100644 --- a/api/utils/strings.go +++ b/api/utils/strings.go @@ -6,6 +6,7 @@ import ( "math/rand" "strings" "time" + "unicode" "golang.org/x/crypto/sha3" ) @@ -94,6 +95,7 @@ func InterfaceToString(value interface{}) string { return JsonEncode(value) } +// CutWords 截取前 N 个单词 func CutWords(str string, num int) string { // 按空格分割字符串为单词切片 words := strings.Fields(str) @@ -105,3 +107,13 @@ func CutWords(str string, num int) string { return str } } + +// HasChinese 判断文本是否含有中文 +func HasChinese(text string) bool { + for _, char := range text { + if unicode.Is(unicode.Scripts["Han"], char) { + return true + } + } + return false +} diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 4359bc66..19a729a6 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -244,7 +244,7 @@