From 2e13ddf405e11c87fc82e4d3a9c76c6318bd625c Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 13 Sep 2023 15:50:00 +0800 Subject: [PATCH] feat: optimize mj notidy api, use job queue to send ai drawing request --- api/handler/mj_handler.go | 54 ++++++++++++++++++++++----------------- api/service/mj_service.go | 6 ++--- api/test/test.go | 7 ++--- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index d90f5e4f..7d668cc4 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { h.lock.Lock() defer h.lock.Unlock() - err := h.notifyHandler(c, data) + err, finished := h.notifyHandler(c, data) if err != nil { resp.ERROR(c, err.Error()) return } // 解除任务锁定 - if data.Status == Finished || data.Status == Stopped { + if finished && (data.Status == Finished || data.Status == Stopped) { h.redis.Del(c, service.MjRunningJobKey) } resp.SUCCESS(c) } -func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error { +func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) { taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() if err != nil { // 过期任务,丢弃 logger.Warn("任务已过期:", err) - return nil + return nil, true } var task service.MjTask err = utils.JsonDecode(taskString, &task) if err != nil { // 非标准任务,丢弃 logger.Warn("任务解析失败:", err) - return nil + return nil, false } if task.Src == service.TaskSrcImg { // 绘画任务 @@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error res := h.db.First(&job, task.Id) if res.Error != nil { logger.Warn("非法任务:", err) - return nil + return nil, false } job.MessageId = data.MessageId job.ReferenceId = data.ReferenceId job.Progress = data.Progress job.Prompt = data.Prompt - // download image + // 任务完成,将最终的图片下载下来 if data.Progress == 100 { imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) if err != nil { logger.Error("error with download img: ", err.Error()) - return err + return err, false } job.ImgURL = imgURL } else { - // 使用图片代理 - job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) + // 临时图片直接保存,访问的时候使用代理进行转发 + job.ImgURL = data.Image.URL } res = h.db.Updates(&job) if res.Error != nil { logger.Error("error with update job: ", err.Error()) - return res.Error + return res.Error, false } } else if task.Src == service.TaskSrcChat { // 聊天任务 @@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error res := h.db.Where("message_id = ?", data.MessageId).First(&job) if res.Error == nil { logger.Warn("重复消息:", data.MessageId) - return nil + return nil, false } wsClient := h.App.MjTaskClients.Get(task.Id) @@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) utils.ReplyMessage(wsClient, content) } - return err + return err, false } tx := h.db.Begin() @@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error } res = tx.Create(&message) if res.Error != nil { - return res.Error + return res.Error, false } // save the job @@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error job.Prompt = data.Prompt job.ImgURL = imgURL job.Progress = data.Progress + job.Hash = data.Image.Hash job.CreatedAt = time.Now() res = tx.Create(&job) if res.Error != nil { tx.Rollback() - return res.Error + return res.Error, false } tx.Commit() } if wsClient == nil { // 客户端断线,则丢弃 logger.Errorf("Client is offline: %+v", data) - return nil + return nil, true } if data.Status == Finished { utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) - // delete client + // 本次绘画完毕,移除客户端 h.App.MjTaskClients.Delete(task.Id) } else { - data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) + // 使用代理临时转发图片 + if data.Image.URL != "" { + image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL) + if err == nil { + data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) } } - return nil + return nil, true } +// Proxy 通过代理访问 discord f服务器图片 func (h *MidJourneyHandler) Proxy(c *gin.Context) { - logger.Info(c.Request.Host, c.Request.Proto) - return - url := c.Query("url") - image, err := utils.DownloadImage(url, h.App.Config.ProxyURL) + imgURL := c.Query("url") + imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL) if err != nil { c.String(http.StatusOK, err.Error()) return } - c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image)) + c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) } type reqVo struct { diff --git a/api/service/mj_service.go b/api/service/mj_service.go index b0295168..b88c27bb 100644 --- a/api/service/mj_service.go +++ b/api/service/mj_service.go @@ -68,9 +68,8 @@ func (s *MjService) Run() { logger.Info("Starting MidJourney job consumer.") ctx := context.Background() for { - t, err := s.redis.Get(ctx, MjRunningJobKey).Result() + _, err := s.redis.Get(ctx, MjRunningJobKey).Result() if err == nil { - logger.Infof("An task is not finished: %s", t) time.Sleep(time.Second * 3) continue } @@ -107,11 +106,12 @@ func (s *MjService) Run() { task.RetryCount += 1 s.taskQueue.RPush(task) // TODO: 执行失败通知聊天客户端 + time.Sleep(time.Second * 3) continue } // 锁定任务执行通道,直到任务超时(10分钟) - s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600) + s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*10) } } diff --git a/api/test/test.go b/api/test/test.go index 5d5bdf6f..87b913f3 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -2,11 +2,12 @@ package main import ( "fmt" - "path/filepath" + "net/url" ) func main() { - imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png" + parse, _ := url.Parse("http://localhost:5678/static") - fmt.Println(filepath.Ext(filepath.Base(imageURL))) + imgURLPrefix := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host) + fmt.Println(imgURLPrefix) }