feat: optimize mj notidy api, use job queue to send ai drawing request

This commit is contained in:
RockYang 2023-09-13 15:50:00 +08:00
parent 1d3acc8ed3
commit 2e13ddf405
3 changed files with 37 additions and 30 deletions

View File

@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
h.lock.Lock() h.lock.Lock()
defer h.lock.Unlock() defer h.lock.Unlock()
err := h.notifyHandler(c, data) err, finished := h.notifyHandler(c, data)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 解除任务锁定 // 解除任务锁定
if data.Status == Finished || data.Status == Stopped { if finished && (data.Status == Finished || data.Status == Stopped) {
h.redis.Del(c, service.MjRunningJobKey) h.redis.Del(c, service.MjRunningJobKey)
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error { func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
if err != nil { // 过期任务,丢弃 if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err) logger.Warn("任务已过期:", err)
return nil return nil, true
} }
var task service.MjTask var task service.MjTask
err = utils.JsonDecode(taskString, &task) err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃 if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err) logger.Warn("任务解析失败:", err)
return nil return nil, false
} }
if task.Src == service.TaskSrcImg { // 绘画任务 if task.Src == service.TaskSrcImg { // 绘画任务
@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
res := h.db.First(&job, task.Id) res := h.db.First(&job, task.Id)
if res.Error != nil { if res.Error != nil {
logger.Warn("非法任务:", err) logger.Warn("非法任务:", err)
return nil return nil, false
} }
job.MessageId = data.MessageId job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId job.ReferenceId = data.ReferenceId
job.Progress = data.Progress job.Progress = data.Progress
job.Prompt = data.Prompt job.Prompt = data.Prompt
// download image // 任务完成,将最终的图片下载下来
if data.Progress == 100 { if data.Progress == 100 {
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
if err != nil { if err != nil {
logger.Error("error with download img: ", err.Error()) logger.Error("error with download img: ", err.Error())
return err return err, false
} }
job.ImgURL = imgURL job.ImgURL = imgURL
} else { } else {
// 使用图片代理 // 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) job.ImgURL = data.Image.URL
} }
res = h.db.Updates(&job) res = h.db.Updates(&job)
if res.Error != nil { if res.Error != nil {
logger.Error("error with update job: ", err.Error()) logger.Error("error with update job: ", err.Error())
return res.Error return res.Error, false
} }
} else if task.Src == service.TaskSrcChat { // 聊天任务 } else if task.Src == service.TaskSrcChat { // 聊天任务
@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
res := h.db.Where("message_id = ?", data.MessageId).First(&job) res := h.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil { if res.Error == nil {
logger.Warn("重复消息:", data.MessageId) logger.Warn("重复消息:", data.MessageId)
return nil return nil, false
} }
wsClient := h.App.MjTaskClients.Get(task.Id) wsClient := h.App.MjTaskClients.Get(task.Id)
@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content) utils.ReplyMessage(wsClient, content)
} }
return err return err, false
} }
tx := h.db.Begin() tx := h.db.Begin()
@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
} }
res = tx.Create(&message) res = tx.Create(&message)
if res.Error != nil { if res.Error != nil {
return res.Error return res.Error, false
} }
// save the job // save the job
@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.ImgURL = imgURL job.ImgURL = imgURL
job.Progress = data.Progress job.Progress = data.Progress
job.Hash = data.Image.Hash
job.CreatedAt = time.Now() job.CreatedAt = time.Now()
res = tx.Create(&job) res = tx.Create(&job)
if res.Error != nil { if res.Error != nil {
tx.Rollback() tx.Rollback()
return res.Error return res.Error, false
} }
tx.Commit() tx.Commit()
} }
if wsClient == nil { // 客户端断线,则丢弃 if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data) logger.Errorf("Client is offline: %+v", data)
return nil return nil, true
} }
if data.Status == Finished { if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// delete client // 本次绘画完毕,移除客户端
h.App.MjTaskClients.Delete(task.Id) h.App.MjTaskClients.Delete(task.Id)
} else { } else {
data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) // 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
} }
} }
return nil return nil, true
} }
// Proxy 通过代理访问 discord f服务器图片
func (h *MidJourneyHandler) Proxy(c *gin.Context) { func (h *MidJourneyHandler) Proxy(c *gin.Context) {
logger.Info(c.Request.Host, c.Request.Proto) imgURL := c.Query("url")
return imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL)
url := c.Query("url")
image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
if err != nil { if err != nil {
c.String(http.StatusOK, err.Error()) c.String(http.StatusOK, err.Error())
return return
} }
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image)) c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
} }
type reqVo struct { type reqVo struct {

View File

@ -68,9 +68,8 @@ func (s *MjService) Run() {
logger.Info("Starting MidJourney job consumer.") logger.Info("Starting MidJourney job consumer.")
ctx := context.Background() ctx := context.Background()
for { for {
t, err := s.redis.Get(ctx, MjRunningJobKey).Result() _, err := s.redis.Get(ctx, MjRunningJobKey).Result()
if err == nil { if err == nil {
logger.Infof("An task is not finished: %s", t)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
@ -107,11 +106,12 @@ func (s *MjService) Run() {
task.RetryCount += 1 task.RetryCount += 1
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
// TODO: 执行失败通知聊天客户端 // TODO: 执行失败通知聊天客户端
time.Sleep(time.Second * 3)
continue continue
} }
// 锁定任务执行通道直到任务超时10分钟 // 锁定任务执行通道直到任务超时10分钟
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600) s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*10)
} }
} }

View File

@ -2,11 +2,12 @@ package main
import ( import (
"fmt" "fmt"
"path/filepath" "net/url"
) )
func main() { func main() {
imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png" parse, _ := url.Parse("http://localhost:5678/static")
fmt.Println(filepath.Ext(filepath.Base(imageURL))) imgURLPrefix := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
fmt.Println(imgURLPrefix)
} }