From 35fedbe817a3ddc14023981a5a5b3fbe4c96ca8a Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 15 Aug 2023 17:31:02 +0800 Subject: [PATCH] feat: midjourney image variation function is ready --- api/core/types/function.go | 6 +-- api/handler/mj_handler.go | 55 ++++++++++++++++++++++----- api/main.go | 1 + api/service/function/mid_journey.go | 27 ++++++++++++- web/src/components/ChatMidJourney.vue | 17 ++++++--- web/src/views/ChatPlus.vue | 14 ++++--- web/src/views/admin/Login.vue | 1 - 7 files changed, 95 insertions(+), 26 deletions(-) diff --git a/api/core/types/function.go b/api/core/types/function.go index cc92dcd9..020ab4bb 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -83,15 +83,15 @@ var InnerFunctions = []Function{ Properties: map[string]Property{ "prompt": { Type: "string", - Description: "绘画内容描述,提示词,此参数需要翻译成英文", + Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文", }, "ar": { Type: "string", - Description: "图片长宽比,如 16:9, --ar 3:2", + Description: "图片长宽比,如 --ar 4:3", }, "niji": { Type: "string", - Description: "动漫模型版本,如 --niji 5", + Description: "动漫模型版本,例如 --niji 5", }, }, Required: []string{}, diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index d7ba91c9..aad4309b 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -8,6 +8,7 @@ import ( "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" + "fmt" "github.com/gin-gonic/gin" "gorm.io/gorm" "time" @@ -83,7 +84,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task) if err != nil { logger.Error("error with get MidJourney task: ", err) - resp.ERROR(c, err.Error()) + resp.SUCCESS(c) return } @@ -138,16 +139,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { resp.SUCCESS(c, "SUCCESS") } +type reqVo struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` + SessionId string `json:"session_id"` + Key string `json:"key"` + Prompt string `json:"prompt"` +} + // Upscale send upscale command to MidJourney Bot func (h *MidJourneyHandler) Upscale(c *gin.Context) { - var data struct { - Index int32 `json:"index"` - MessageId string `json:"message_id"` - MessageHash string `json:"message_hash"` - SessionId string `json:"session_id"` - Key string `json:"key"` - } - + var data reqVo if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" || data.Key == "" { @@ -170,7 +173,39 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - utils.ReplyMessage(wsClient, "已推送放大图片任务到 MidJourney 机器人,请耐心等待任务执行...") + content := fmt.Sprintf("**%s** 已推送 Upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + utils.ReplyMessage(wsClient, content) + if h.App.MjTaskClients.Get(data.Key) == nil { + h.App.MjTaskClients.Put(data.Key, wsClient) + } + resp.SUCCESS(c) +} + +func (h *MidJourneyHandler) Variation(c *gin.Context) { + var data reqVo + if err := c.ShouldBindJSON(&data); err != nil || + data.SessionId == "" || + data.Key == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + wsClient := h.App.ChatClients.Get(data.SessionId) + if wsClient == nil { + resp.ERROR(c, "No Websocket client online") + return + } + + err := h.mjFunc.Variation(function.MjVariationReq{ + Index: data.Index, + MessageId: data.MessageId, + MessageHash: data.MessageHash, + }) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + content := fmt.Sprintf("**%s** 已推送 Variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + utils.ReplyMessage(wsClient, content) if h.App.MjTaskClients.Get(data.Key) == nil { h.App.MjTaskClients.Put(data.Key, wsClient) } diff --git a/api/main.go b/api/main.go index 5281c65f..2b8f8a3e 100644 --- a/api/main.go +++ b/api/main.go @@ -179,6 +179,7 @@ func main() { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { s.Engine.POST("/api/mj/notify", h.Notify) s.Engine.POST("/api/mj/upscale", h.Upscale) + s.Engine.POST("/api/mj/variation", h.Variation) }), // 管理后台控制器 diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index 47455db1..cf5455e4 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -81,7 +81,32 @@ func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error { if res.Code != types.Success { return errors.New(res.Message) } - + + return nil +} + +type MjVariationReq struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` +} + +func (f FuncMidJourney) Variation(upReq MjVariationReq) error { + url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL) + var res types.BizVo + r, err := f.client.R(). + SetHeader("Authorization", f.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(upReq). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + return nil } diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue index 593a1dd8..b64710c5 100644 --- a/web/src/components/ChatMidJourney.vue +++ b/web/src/components/ChatMidJourney.vue @@ -92,14 +92,23 @@ watch(() => props.content, (newVal) => { }); const emits = defineEmits(['disable-input', 'disable-input']); const upscale = (index) => { + send('/api/mj/upscale', index) +} + +const variation = (index) => { + send('/api/mj/variation', index) +} + +const send = (url, index) => { loading.value = true emits('disable-input') - httpPost("/api/mj/upscale", { + httpPost(url, { index: index, message_id: data.value?.["message_id"], message_hash: data.value?.["image"]?.hash, session_id: getSessionId(), - key: data.value?.["key"] + key: data.value?.["key"], + prompt: data.value?.["prompt"], }).then(() => { ElMessage.success("任务推送成功,请耐心等待任务执行...") loading.value = false @@ -108,10 +117,6 @@ const upscale = (index) => { emits('disable-input') }) } - -const variation = (index) => { - ElMessage.warning("当前版本暂未实现 Variation 功能!") -}