feat: auto translate and rewrite prompt for midjourney and stable-diffusion

This commit is contained in:
RockYang
2024-03-27 13:45:52 +08:00
parent b60a639312
commit 9794d67eaa
18 changed files with 162 additions and 355 deletions

View File

@@ -22,16 +22,7 @@ type Client struct {
}
func NewClient(config types.MidJourneyPlusConfig) *Client {
var apiURL string
if config.CdnURL != "" {
apiURL = config.CdnURL
} else {
apiURL = config.ApiURL
}
if config.Mode == "" {
config.Mode = "fast"
}
return &Client{Config: config, apiURL: apiURL}
return &Client{Config: config, apiURL: config.ApiURL}
}
type ImageReq struct {
@@ -81,6 +72,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
@@ -90,9 +82,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -132,8 +122,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -183,8 +172,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {

View File

@@ -167,11 +167,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
if s.Client.Config.CdnURL != "" {
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
} else {
job.OrgURL = task.ImageUrl
}
job.OrgURL = task.ImageUrl
}
job.MessageId = task.Id
tx := s.db.Updates(&job)

View File

@@ -2,8 +2,11 @@ package mj
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
@@ -62,6 +65,14 @@ func (s *Service) Run() {
continue
}
// 翻译提示词
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
}
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage: