mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: auto translate and rewrite prompt for midjourney and stable-diffusion
This commit is contained in:
		@@ -22,16 +22,7 @@ type Client struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	if config.CdnURL != "" {
 | 
			
		||||
		apiURL = config.CdnURL
 | 
			
		||||
	} else {
 | 
			
		||||
		apiURL = config.ApiURL
 | 
			
		||||
	}
 | 
			
		||||
	if config.Mode == "" {
 | 
			
		||||
		config.Mode = "fast"
 | 
			
		||||
	}
 | 
			
		||||
	return &Client{Config: config, apiURL: apiURL}
 | 
			
		||||
	return &Client{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageReq struct {
 | 
			
		||||
@@ -81,6 +72,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
@@ -90,9 +82,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
@@ -132,8 +122,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
@@ -183,8 +172,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
 
 | 
			
		||||
@@ -167,11 +167,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
	job.Prompt = task.PromptEn
 | 
			
		||||
	if task.ImageUrl != "" {
 | 
			
		||||
		if s.Client.Config.CdnURL != "" {
 | 
			
		||||
			job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
 | 
			
		||||
		} else {
 | 
			
		||||
			job.OrgURL = task.ImageUrl
 | 
			
		||||
		}
 | 
			
		||||
		job.OrgURL = task.ImageUrl
 | 
			
		||||
	}
 | 
			
		||||
	job.MessageId = task.Id
 | 
			
		||||
	tx := s.db.Updates(&job)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,8 +2,11 @@ package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -62,6 +65,14 @@ func (s *Service) Run() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 翻译提示词
 | 
			
		||||
		if utils.HasChinese(task.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Prompt = content
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package sd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
@@ -46,6 +47,14 @@ func (s *Service) Run() {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// 翻译提示词
 | 
			
		||||
		if utils.HasChinese(task.Params.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Params.Prompt = content
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
 | 
			
		||||
		err = s.Txt2Img(task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -66,7 +75,7 @@ func (s *Service) Run() {
 | 
			
		||||
type Txt2ImgReq struct {
 | 
			
		||||
	Prompt            string  `json:"prompt"`
 | 
			
		||||
	NegativePrompt    string  `json:"negative_prompt"`
 | 
			
		||||
	Seed              int64   `json:"seed"`
 | 
			
		||||
	Seed              int64   `json:"seed,omitempty"`
 | 
			
		||||
	Steps             int     `json:"steps"`
 | 
			
		||||
	CfgScale          float32 `json:"cfg_scale"`
 | 
			
		||||
	Width             int     `json:"width"`
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										4
									
								
								api/service/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								api/service/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
 | 
			
		||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
 | 
			
		||||
		Reference in New Issue
	
	Block a user