feat: auto translate and rewrite prompt for midjourney and stable-diffusion

This commit is contained in:
RockYang
2024-03-27 13:45:52 +08:00
parent 342b76f666
commit b5947545cb
18 changed files with 162 additions and 355 deletions

View File

@@ -22,16 +22,7 @@ type Client struct {
}
func NewClient(config types.MidJourneyPlusConfig) *Client {
var apiURL string
if config.CdnURL != "" {
apiURL = config.CdnURL
} else {
apiURL = config.ApiURL
}
if config.Mode == "" {
config.Mode = "fast"
}
return &Client{Config: config, apiURL: apiURL}
return &Client{Config: config, apiURL: config.ApiURL}
}
type ImageReq struct {
@@ -81,6 +72,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
@@ -90,9 +82,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -132,8 +122,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
@@ -183,8 +172,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {

View File

@@ -167,11 +167,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
if s.Client.Config.CdnURL != "" {
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
} else {
job.OrgURL = task.ImageUrl
}
job.OrgURL = task.ImageUrl
}
job.MessageId = task.Id
tx := s.db.Updates(&job)

View File

@@ -2,8 +2,11 @@ package mj
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
@@ -62,6 +65,14 @@ func (s *Service) Run() {
continue
}
// 翻译提示词
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content
}
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:

View File

@@ -2,6 +2,7 @@ package sd
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
@@ -46,6 +47,14 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
// 翻译提示词
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
@@ -66,7 +75,7 @@ func (s *Service) Run() {
type Txt2ImgReq struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Seed int64 `json:"seed"`
Seed int64 `json:"seed,omitempty"`
Steps int `json:"steps"`
CfgScale float32 `json:"cfg_scale"`
Width int `json:"width"`

4
api/service/types.go Normal file
View File

@@ -0,0 +1,4 @@
package service
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"