feat: optimize mj notidy api, use job queue to send ai drawing request

This commit is contained in:
RockYang 2023-09-13 15:50:00 +08:00
parent 1d3acc8ed3
commit 2e13ddf405
3 changed files with 37 additions and 30 deletions

View File

@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
h.lock.Lock()
defer h.lock.Unlock()
err := h.notifyHandler(c, data)
err, finished := h.notifyHandler(c, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 解除任务锁定
if data.Status == Finished || data.Status == Stopped {
if finished && (data.Status == Finished || data.Status == Stopped) {
h.redis.Del(c, service.MjRunningJobKey)
}
resp.SUCCESS(c)
}
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error {
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return nil
return nil, true
}
var task service.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return nil
return nil, false
}
if task.Src == service.TaskSrcImg { // 绘画任务
@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
res := h.db.First(&job, task.Id)
if res.Error != nil {
logger.Warn("非法任务:", err)
return nil
return nil, false
}
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
// download image
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
if err != nil {
logger.Error("error with download img: ", err.Error())
return err
return err, false
}
job.ImgURL = imgURL
} else {
// 使用图片代理
job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = h.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", err.Error())
return res.Error
return res.Error, false
}
} else if task.Src == service.TaskSrcChat { // 聊天任务
@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil {
logger.Warn("重复消息:", data.MessageId)
return nil
return nil, false
}
wsClient := h.App.MjTaskClients.Get(task.Id)
@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
return err
return err, false
}
tx := h.db.Begin()
@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
}
res = tx.Create(&message)
if res.Error != nil {
return res.Error
return res.Error, false
}
// save the job
@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
job.Prompt = data.Prompt
job.ImgURL = imgURL
job.Progress = data.Progress
job.Hash = data.Image.Hash
job.CreatedAt = time.Now()
res = tx.Create(&job)
if res.Error != nil {
tx.Rollback()
return res.Error
return res.Error, false
}
tx.Commit()
}
if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data)
return nil
return nil, true
}
if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// delete client
// 本次绘画完毕,移除客户端
h.App.MjTaskClients.Delete(task.Id)
} else {
data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
// 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
}
}
return nil
return nil, true
}
// Proxy 通过代理访问 discord f服务器图片
func (h *MidJourneyHandler) Proxy(c *gin.Context) {
logger.Info(c.Request.Host, c.Request.Proto)
return
url := c.Query("url")
image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
imgURL := c.Query("url")
imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL)
if err != nil {
c.String(http.StatusOK, err.Error())
return
}
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image))
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
type reqVo struct {

View File

@ -68,9 +68,8 @@ func (s *MjService) Run() {
logger.Info("Starting MidJourney job consumer.")
ctx := context.Background()
for {
t, err := s.redis.Get(ctx, MjRunningJobKey).Result()
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
if err == nil {
logger.Infof("An task is not finished: %s", t)
time.Sleep(time.Second * 3)
continue
}
@ -107,11 +106,12 @@ func (s *MjService) Run() {
task.RetryCount += 1
s.taskQueue.RPush(task)
// TODO: 执行失败通知聊天客户端
time.Sleep(time.Second * 3)
continue
}
// 锁定任务执行通道直到任务超时10分钟
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600)
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*10)
}
}

View File

@ -2,11 +2,12 @@ package main
import (
"fmt"
"path/filepath"
"net/url"
)
func main() {
imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png"
parse, _ := url.Parse("http://localhost:5678/static")
fmt.Println(filepath.Ext(filepath.Base(imageURL)))
imgURLPrefix := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
fmt.Println(imgURLPrefix)
}