From fa341bab3030f0313a757cb7c4a62d6da77c454e Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 12 Sep 2023 18:01:24 +0800 Subject: [PATCH] feat: refactor MidJourney service for conpatible drawing in chat and draw in app --- api/core/types/chat.go | 11 -- api/handler/azure_handler.go | 23 +-- api/handler/mj_handler.go | 221 ++++++++++++++++---------- api/handler/openai_handler.go | 23 +-- api/main.go | 14 +- api/service/function/func_mj.go | 64 ++++++++ api/service/function/mid_journey.go | 129 --------------- api/service/mj_service.go | 189 ++++++++++++++++++++++ api/store/model/mj_job.go | 9 +- api/store/redis_queue.go | 41 +++++ database/update-3.1.3.sql | 4 + web/src/components/ChatMidJourney.vue | 1 - 12 files changed, 467 insertions(+), 262 deletions(-) create mode 100644 api/service/function/func_mj.go delete mode 100644 api/service/function/mid_journey.go create mode 100644 api/service/mj_service.go create mode 100644 api/store/redis_queue.go create mode 100644 database/update-3.1.3.sql diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 07e2e13f..f9f20f17 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -49,15 +49,6 @@ type ChatModel struct { Value string `json:"value"` } -type MjTask struct { - ChatId string - MessageId string - MessageHash string - UserId uint - RoleId uint - Icon string -} - type ApiError struct { Error struct { Message string @@ -77,5 +68,3 @@ var ModelToTokens = map[string]int{ "gpt-4": 8192, "gpt-4-32k": 32768, } - -const TaskStorePrefix = "/tasks/" diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go index 9238036c..a87f8e00 100644 --- a/api/handler/azure_handler.go +++ b/api/handler/azure_handler.go @@ -131,6 +131,13 @@ func (h *ChatHandler) sendAzureMessage( utils.ReplyMessage(ws, "![](/images/wx.png)") } else { f := h.App.Functions[functionName] + if functionName == types.FuncMidJourney { + params["user_id"] = userVo.Id + params["role_id"] = role.Id + params["chat_id"] = session.ChatId + params["icon"] = "/images/avatar/mid_journey.png" + params["session_id"] = session.SessionId + } data, err := f.Invoke(params) if err != nil { msg := "调用函数出错:" + err.Error() @@ -142,22 +149,8 @@ func (h *ChatHandler) sendAzureMessage( } else { content := data if functionName == types.FuncMidJourney { - key := utils.Sha256(data) - logger.Debug(data, ",", key) - // add task for MidJourney - h.App.MjTaskClients.Put(key, ws) - task := types.MjTask{ - UserId: userVo.Id, - RoleId: role.Id, - Icon: "/images/avatar/mid_journey.png", - ChatId: session.ChatId, - } - err := h.leveldb.Put(types.TaskStorePrefix+key, task) - if err != nil { - logger.Error("error with store MidJourney task: ", err) - } content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - + h.App.MjTaskClients.Put(session.SessionId, ws) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 7da7ba21..414157c9 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,16 +3,18 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service" "chatplus/service/function" "chatplus/service/oss" - "chatplus/store" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "encoding/base64" "fmt" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "gorm.io/gorm" + "net/http" "sync" "time" ) @@ -38,25 +40,26 @@ type Image struct { type MidJourneyHandler struct { BaseHandler - leveldb *store.LevelDB + redis *redis.Client db *gorm.DB - mjFunc function.FuncMidJourney + mjService *service.MjService uploaderManager *oss.UploaderManager lock sync.Mutex } func NewMidJourneyHandler( app *core.AppServer, - leveldb *store.LevelDB, + client *redis.Client, db *gorm.DB, manager *oss.UploaderManager, - functions map[string]function.Function) *MidJourneyHandler { + mjService *service.MjService) *MidJourneyHandler { h := MidJourneyHandler{ - leveldb: leveldb, + redis: client, db: db, uploaderManager: manager, lock: sync.Mutex{}, - mjFunc: functions[types.FuncMidJourney].(function.FuncMidJourney)} + mjService: mjService, + } h.App = app return &h } @@ -75,7 +78,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { Content string `json:"content"` Prompt string `json:"prompt"` Status TaskStatus `json:"status"` - Key string `json:"key"` + Progress int `json:"progress"` } if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { resp.ERROR(c, types.InvalidArgs) @@ -86,95 +89,142 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { h.lock.Lock() defer h.lock.Unlock() - // the job is saved - var job model.MidJourneyJob - res := h.db.Where("message_id = ?", data.MessageId).First(&job) - if res.Error == nil { - resp.SUCCESS(c) + taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() + if err != nil { + resp.SUCCESS(c) // 过期任务,丢弃 return } - data.Key = utils.Sha256(data.Prompt) - wsClient := h.App.MjTaskClients.Get(data.Key) - //logger.Info(data.Prompt, ",", key) - if data.Status == Finished { - var task types.MjTask - err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task) - if err != nil { - logger.Error("error with get MidJourney task: ", err) + var task service.MjTask + err = utils.JsonDecode(taskString, &task) + if err != nil { + resp.SUCCESS(c) // 非标准任务,丢弃 + return + } + + if task.Src == service.TaskSrcImg { // 绘画任务 + logger.Error(err) + var job model.MidJourneyJob + res := h.db.First(&job, task.Id) + if res.Error != nil { + resp.SUCCESS(c) // 非法任务,丢弃 + return + } + job.MessageId = data.MessageId + job.ReferenceId = data.ReferenceId + job.Progress = data.Progress + + // download image + if data.Progress == 100 { + imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) + if err != nil { + resp.ERROR(c, "error with download img: "+err.Error()) + return + } + job.ImgURL = imgURL + } else { + // 使用图片代理 + job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) + } + res = h.db.Updates(&job) + if res.Error != nil { + resp.ERROR(c, "error with update job: "+err.Error()) + return + } + + resp.SUCCESS(c) + + } else if task.Src == service.TaskSrcChat { // 聊天任务 + var job model.MidJourneyJob + res := h.db.Where("message_id = ?", data.MessageId).First(&job) + if res.Error == nil { resp.SUCCESS(c) return } - if wsClient != nil && data.ReferenceId != "" { - content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) - utils.ReplyMessage(wsClient, content) - } - // download image - imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) - if err != nil { - logger.Error("error with download image: ", err) + + wsClient := h.App.MjTaskClients.Get(task.Id) + if data.Status == Finished { if wsClient != nil && data.ReferenceId != "" { - content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) + content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) utils.ReplyMessage(wsClient, content) } - resp.ERROR(c, err.Error()) + // download image + imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) + if err != nil { + logger.Error("error with download image: ", err) + if wsClient != nil && data.ReferenceId != "" { + content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) + utils.ReplyMessage(wsClient, content) + } + resp.ERROR(c, err.Error()) + return + } + + data.Image.URL = imgURL + message := model.HistoryMessage{ + UserId: uint(task.UserId), + ChatId: task.ChatId, + RoleId: uint(task.RoleId), + Type: types.MjMsg, + Icon: task.Icon, + Content: utils.JsonEncode(data), + Tokens: 0, + UseContext: false, + } + res := h.db.Create(&message) + if res.Error != nil { + logger.Error("error with save chat history message: ", res.Error) + } + + // save the job + job.UserId = task.UserId + job.MessageId = data.MessageId + job.ReferenceId = data.ReferenceId + job.Prompt = data.Prompt + job.ImgURL = imgURL + job.Progress = data.Progress + job.CreatedAt = time.Now() + res = h.db.Create(&job) + if res.Error != nil { + logger.Error("error with save MidJourney Job: ", res.Error) + } + } + + if wsClient == nil { // 客户端断线,则丢弃 + logger.Errorf("Client is offline: %+v", data) + resp.SUCCESS(c, "Client is offline") return } - data.Image.URL = imgURL - message := model.HistoryMessage{ - UserId: task.UserId, - ChatId: task.ChatId, - RoleId: task.RoleId, - Type: types.MjMsg, - Icon: task.Icon, - Content: utils.JsonEncode(data), - Tokens: 0, - UseContext: false, - } - res := h.db.Create(&message) - if res.Error != nil { - logger.Error("error with save chat history message: ", res.Error) - } - - // save the job - job.UserId = task.UserId - job.ChatId = task.ChatId - job.MessageId = data.MessageId - job.ReferenceId = data.ReferenceId - job.Content = data.Content - job.Prompt = data.Prompt - job.Image = utils.JsonEncode(data.Image) - job.Hash = data.Image.Hash - job.CreatedAt = time.Now() - res = h.db.Create(&job) - if res.Error != nil { - logger.Error("error with save MidJourney Job: ", res.Error) + if data.Status == Finished { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) + // delete client + h.App.MjTaskClients.Delete(task.Id) + } else { + //// 使用代理临时转发图片 + //if data.Image.URL != "" { + // image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL) + // if err == nil { + // data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + // } + //} + data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) } + resp.SUCCESS(c, "SUCCESS") } - if wsClient == nil { // 客户端断线,则丢弃 - logger.Errorf("Client is offline: %+v", data) - resp.SUCCESS(c, "Client is offline") +} + +func (h *MidJourneyHandler) Proxy(c *gin.Context) { + url := c.Query("url") + image, err := utils.DownloadImage(url, h.App.Config.ProxyURL) + if err != nil { + c.String(http.StatusOK, err.Error()) return } - - if data.Status == Finished { - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) - // delete client - h.App.MjTaskClients.Delete(data.Key) - } else { - // 使用代理临时转发图片 - if data.Image.URL != "" { - image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL) - if err == nil { - data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) - } - resp.SUCCESS(c, "SUCCESS") + c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image)) } type reqVo struct { @@ -201,7 +251,12 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - err := h.mjFunc.Upscale(function.MjUpscaleReq{ + h.mjService.PushTask(service.MjTask{ + Index: data.Index, + MessageId: data.MessageId, + MessageHash: data.MessageHash, + }) + err := n.Upscale(function.MjUpscaleReq{ Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, @@ -211,7 +266,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - content := fmt.Sprintf("**%s** 已推送 Upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) utils.ReplyMessage(wsClient, content) if h.App.MjTaskClients.Get(data.Key) == nil { h.App.MjTaskClients.Put(data.Key, wsClient) @@ -242,7 +297,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { resp.ERROR(c, err.Error()) return } - content := fmt.Sprintf("**%s** 已推送 Variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) utils.ReplyMessage(wsClient, content) if h.App.MjTaskClients.Get(data.Key) == nil { h.App.MjTaskClients.Put(data.Key, wsClient) diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go index 73798b8b..dd7d7bac 100644 --- a/api/handler/openai_handler.go +++ b/api/handler/openai_handler.go @@ -131,6 +131,13 @@ func (h *ChatHandler) sendOpenAiMessage( utils.ReplyMessage(ws, "![](/images/wx.png)") } else { f := h.App.Functions[functionName] + if functionName == types.FuncMidJourney { + params["user_id"] = userVo.Id + params["role_id"] = role.Id + params["chat_id"] = session.ChatId + params["icon"] = "/images/avatar/mid_journey.png" + params["session_id"] = session.SessionId + } data, err := f.Invoke(params) if err != nil { msg := "调用函数出错:" + err.Error() @@ -142,22 +149,8 @@ func (h *ChatHandler) sendOpenAiMessage( } else { content := data if functionName == types.FuncMidJourney { - key := utils.Sha256(data) - logger.Debug(data, ",", key) - // add task for MidJourney - h.App.MjTaskClients.Put(key, ws) - task := types.MjTask{ - UserId: userVo.Id, - RoleId: role.Id, - Icon: "/images/avatar/mid_journey.png", - ChatId: session.ChatId, - } - err := h.leveldb.Put(types.TaskStorePrefix+key, task) - if err != nil { - logger.Error("error with store MidJourney task: ", err) - } content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - + h.App.MjTaskClients.Put(session.SessionId, ws) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/main.go b/api/main.go index 2860392f..f04cccd5 100644 --- a/api/main.go +++ b/api/main.go @@ -135,6 +135,12 @@ func main() { return service.NewCaptchaService(config.ApiConfig) }), fx.Provide(oss.NewUploaderManager), + fx.Provide(service.NewMjService), + fx.Provide(func(mjService *service.MjService) { + go func() { + mjService.Run() + }() + }), // 注册路由 fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { @@ -183,9 +189,11 @@ func main() { group.POST("verify", h.Verify) }), fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { - s.Engine.POST("/api/mj/notify", h.Notify) - s.Engine.POST("/api/mj/upscale", h.Upscale) - s.Engine.POST("/api/mj/variation", h.Variation) + group := s.Engine.Group("/api/mj/") + group.POST("notify", h.Notify) + group.POST("upscale", h.Upscale) + group.POST("variation", h.Variation) + group.GET("proxy", h.Proxy) }), // 管理后台控制器 diff --git a/api/service/function/func_mj.go b/api/service/function/func_mj.go new file mode 100644 index 00000000..79cc7313 --- /dev/null +++ b/api/service/function/func_mj.go @@ -0,0 +1,64 @@ +package function + +import ( + "chatplus/service" + "chatplus/utils" + "fmt" +) + +// AI 绘画函数 + +type FuncMidJourney struct { + name string + service *service.MjService +} + +func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney { + return FuncMidJourney{ + name: "MidJourney AI 绘画", + service: mjService} +} + +func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { + logger.Infof("MJ 绘画参数:%+v", params) + prompt := utils.InterfaceToString(params["prompt"]) + if !utils.IsEmptyValue(params["--ar"]) { + prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"]) + delete(params, "--ar") + } + if !utils.IsEmptyValue(params["--s"]) { + prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"]) + delete(params, "--s") + } + if !utils.IsEmptyValue(params["--seed"]) { + prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"]) + delete(params, "--seed") + } + if !utils.IsEmptyValue(params["--no"]) { + prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"]) + delete(params, "--no") + } + if !utils.IsEmptyValue(params["--niji"]) { + prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"]) + delete(params, "--niji") + } else { + prompt = prompt + " --v 5.2" + } + + f.service.PushTask(service.MjTask{ + Id: utils.InterfaceToString(params["session_id"]), + Src: service.TaskSrcChat, + Prompt: prompt, + UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), + RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), + Icon: utils.InterfaceToString(params["icon"]), + ChatId: utils.InterfaceToString(params["chat_id"]), + }) + return prompt, nil +} + +func (f FuncMidJourney) Name() string { + return f.name +} + +var _ Function = &FuncMidJourney{} diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go deleted file mode 100644 index 237ec08c..00000000 --- a/api/service/function/mid_journey.go +++ /dev/null @@ -1,129 +0,0 @@ -package function - -import ( - "chatplus/core/types" - "chatplus/utils" - "errors" - "fmt" - "github.com/imroc/req/v3" - "time" -) - -// AI 绘画函数 - -type FuncMidJourney struct { - name string - config types.ChatPlusExtConfig - client *req.Client -} - -func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney { - return FuncMidJourney{ - name: "MidJourney AI 绘画", - config: config, - client: req.C().SetTimeout(30 * time.Second)} -} - -func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { - if f.config.Token == "" { - return "", errors.New("无效的 API Token") - } - - logger.Infof("MJ 绘画参数:%+v", params) - prompt := utils.InterfaceToString(params["prompt"]) - if !utils.IsEmptyValue(params["--ar"]) { - prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"]) - delete(params, "--ar") - } - if !utils.IsEmptyValue(params["--s"]) { - prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"]) - delete(params, "--s") - } - if !utils.IsEmptyValue(params["--seed"]) { - prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"]) - delete(params, "--seed") - } - if !utils.IsEmptyValue(params["--no"]) { - prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"]) - delete(params, "--no") - } - if !utils.IsEmptyValue(params["--niji"]) { - prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"]) - delete(params, "--niji") - } else { - prompt = prompt + " --v 5.2" - } - params["prompt"] = prompt - url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL) - var res types.BizVo - r, err := f.client.R(). - SetHeader("Authorization", f.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(params). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return "", fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return "", errors.New(res.Message) - } - - return prompt, nil -} - -type MjUpscaleReq struct { - Index int32 `json:"index"` - MessageId string `json:"message_id"` - MessageHash string `json:"message_hash"` -} - -func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error { - url := fmt.Sprintf("%s/api/mj/upscale", f.config.ApiURL) - var res types.BizVo - r, err := f.client.R(). - SetHeader("Authorization", f.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(upReq). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} - -type MjVariationReq struct { - Index int32 `json:"index"` - MessageId string `json:"message_id"` - MessageHash string `json:"message_hash"` -} - -func (f FuncMidJourney) Variation(upReq MjVariationReq) error { - url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL) - var res types.BizVo - r, err := f.client.R(). - SetHeader("Authorization", f.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(upReq). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} - -func (f FuncMidJourney) Name() string { - return f.name -} - -var _ Function = &FuncMidJourney{} diff --git a/api/service/mj_service.go b/api/service/mj_service.go new file mode 100644 index 00000000..ccd70ed9 --- /dev/null +++ b/api/service/mj_service.go @@ -0,0 +1,189 @@ +package service + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "chatplus/store" + "chatplus/utils" + "context" + "errors" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/imroc/req/v3" + "time" +) + +var logger = logger2.GetLogger() + +// MJ 绘画服务 + +const MjRunningJobKey = "MidJourney_Running_Job" + +type TaskType string + +const ( + Image = TaskType("image") + Upscale = TaskType("upscale") + Variation = TaskType("variation") +) + +type TaskSrc string + +const ( + TaskSrcChat = TaskSrc("chat") + TaskSrcImg = TaskSrc("img") +) + +type MjTask struct { + Id string `json:"id"` + Src TaskSrc `json:"src"` + Type TaskType `json:"type"` + UserId int `json:"user_id"` + Prompt string `json:"prompt,omitempty"` + ChatId string `json:"chat_id,omitempty"` + RoleId int `json:"role_id,omitempty"` + Icon string `json:"icon,omitempty"` + Index int32 `json:"index,omitempty"` + MessageId string `json:"message_id,omitempty"` + MessageHash string `json:"message_hash,omitempty"` + RetryCount int `json:"retry_count"` +} + +type MjService struct { + config types.ChatPlusExtConfig + client *req.Client + taskQueue *store.RedisQueue + redis *redis.Client +} + +func NewMjService(config types.ChatPlusExtConfig, client *redis.Client) *MjService { + return &MjService{ + config: config, + redis: client, + taskQueue: store.NewRedisQueue("midjourney_task_queue", client), + client: req.C().SetTimeout(30 * time.Second)} +} + +func (s *MjService) Run() { + ctx := context.Background() + for { + _, err := s.redis.Get(ctx, MjRunningJobKey).Result() + if err == nil { // a task is running, waiting for finish + time.Sleep(time.Second * 3) + continue + } + var task MjTask + err = s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + switch task.Type { + case Image: + err = s.image(task.Prompt) + break + case Upscale: + err = s.upscale(MjUpscaleReq{ + Index: task.Index, + MessageId: task.MessageId, + MessageHash: task.MessageHash, + }) + break + case Variation: + err = s.variation(MjVariationReq{ + Index: task.Index, + MessageId: task.MessageId, + MessageHash: task.MessageHash, + }) + } + if err != nil { + if task.RetryCount > 5 { + continue + } + task.RetryCount += 1 + time.Sleep(time.Second) + s.taskQueue.RPush(task) + // TODO: 执行失败通知聊天客户端 + continue + } + + // 锁定任务执行通道,直到任务超时(10分钟) + s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600) + } +} + +func (s *MjService) PushTask(task MjTask) { + s.taskQueue.RPush(task) +} + +func (s *MjService) image(prompt string) error { + logger.Infof("MJ 绘画参数:%+v", prompt) + body := map[string]string{"prompt": prompt} + url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("Authorization", s.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(body). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + + return nil +} + +type MjUpscaleReq struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` +} + +func (s *MjService) upscale(upReq MjUpscaleReq) error { + url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("Authorization", s.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(upReq). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + + return nil +} + +type MjVariationReq struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` +} + +func (s *MjService) variation(upReq MjVariationReq) error { + url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("Authorization", s.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(upReq). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + + return nil +} diff --git a/api/store/model/mj_job.go b/api/store/model/mj_job.go index 3162273a..30ac00ca 100644 --- a/api/store/model/mj_job.go +++ b/api/store/model/mj_job.go @@ -4,14 +4,13 @@ import "time" type MidJourneyJob struct { Id uint `gorm:"primarykey;column:id"` - UserId uint - ChatId string + UserId int MessageId string ReferenceId string - Hash string - Content string + ImgURL string + Hash string // message hash + Progress int Prompt string - Image string CreatedAt time.Time } diff --git a/api/store/redis_queue.go b/api/store/redis_queue.go new file mode 100644 index 00000000..71a730a6 --- /dev/null +++ b/api/store/redis_queue.go @@ -0,0 +1,41 @@ +package store + +import ( + "chatplus/utils" + "context" + "github.com/go-redis/redis/v8" +) + +type RedisQueue struct { + name string + client *redis.Client + ctx context.Context +} + +func NewRedisQueue(name string, client *redis.Client) *RedisQueue { + return &RedisQueue{name: name, client: client, ctx: context.Background()} +} + +func (q *RedisQueue) RPush(value interface{}) { + q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)) +} + +func (q *RedisQueue) LPush(value interface{}) { + q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)) +} + +func (q *RedisQueue) LPop(value interface{}) error { + result, err := q.client.BLPop(q.ctx, 0, q.name).Result() + if err != nil { + return err + } + return utils.JsonDecode(result[1], value) +} + +func (q *RedisQueue) RPop(value interface{}) error { + result, err := q.client.BRPop(q.ctx, 0, q.name).Result() + if err != nil { + return err + } + return utils.JsonDecode(result[1], value) +} diff --git a/database/update-3.1.3.sql b/database/update-3.1.3.sql new file mode 100644 index 00000000..9871fcfa --- /dev/null +++ b/database/update-3.1.3.sql @@ -0,0 +1,4 @@ +ALTER TABLE `chatgpt_mj_jobs` DROP `image`; +ALTER TABLE `chatgpt_mj_jobs` ADD `progress` SMALLINT(5) NULL DEFAULT '0' COMMENT '任务进度' AFTER `prompt`; +ALTER TABLE `chatgpt_mj_jobs` ADD `hash` VARCHAR(100) NULL DEFAULT NULL COMMENT 'message hash' AFTER `prompt`; +ALTER TABLE `chatgpt_mj_jobs` ADD `img_url` VARCHAR(255) NULL DEFAULT NULL COMMENT '图片URL' AFTER `prompt`; \ No newline at end of file diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue index 1de544ea..a62e03ca 100644 --- a/web/src/components/ChatMidJourney.vue +++ b/web/src/components/ChatMidJourney.vue @@ -109,7 +109,6 @@ const send = (url, index) => { message_id: data.value?.["message_id"], message_hash: data.value?.["image"]?.hash, session_id: getSessionId(), - key: data.value?.["key"], prompt: data.value?.["prompt"], }).then(() => { ElMessage.success("任务推送成功,请耐心等待任务执行...")