refactor: add midjourney pool implementation, add translate prompt for mj drawing

This commit is contained in:
RockYang
2023-12-13 16:38:27 +08:00
parent 8f4d20e411
commit 6d71f24f75
16 changed files with 226 additions and 272 deletions

View File

@@ -13,7 +13,8 @@ import (
"gorm.io/gorm"
)
const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct {
BaseHandler
@@ -47,6 +48,25 @@ type apiErrRes struct {
} `json:"error"`
}
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, rewritePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
@@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, translatePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY
var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
if res.Error != nil {
resp.ERROR(c, "找不到可用 OpenAI API KEY")
return
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}
messages := make([]interface{}, 1)
messages[0] = types.Message{
Role: "user",
Content: fmt.Sprintf(translatePromptTemplate, data.Prompt),
Content: fmt.Sprintf(promptTemplate, prompt),
}
var response apiRes
@@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) {
SetErrorResult(&errRes).
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message))
return
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
}
resp.SUCCESS(c, response.Choices[0].Message.Content)
return response.Choices[0].Message.Content, nil
}