mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-06 09:13:47 +08:00
feat: optimize mj notidy api, use job queue to send ai drawing request
This commit is contained in:
@@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
err := h.notifyHandler(c, data)
|
||||
err, finished := h.notifyHandler(c, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解除任务锁定
|
||||
if data.Status == Finished || data.Status == Stopped {
|
||||
if finished && (data.Status == Finished || data.Status == Stopped) {
|
||||
h.redis.Del(c, service.MjRunningJobKey)
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error {
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
|
||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
var task service.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if task.Src == service.TaskSrcImg { // 绘画任务
|
||||
@@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
res := h.db.First(&job, task.Id)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", err)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
|
||||
// download image
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return err
|
||||
return err, false
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
} else {
|
||||
// 使用图片代理
|
||||
job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||
job.ImgURL = data.Image.URL
|
||||
}
|
||||
res = h.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", err.Error())
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
|
||||
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
||||
@@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||
if res.Error == nil {
|
||||
logger.Warn("重复消息:", data.MessageId)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
wsClient := h.App.MjTaskClients.Get(task.Id)
|
||||
@@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
return err
|
||||
return err, false
|
||||
}
|
||||
|
||||
tx := h.db.Begin()
|
||||
@@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
}
|
||||
res = tx.Create(&message)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
|
||||
// save the job
|
||||
@@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
job.Prompt = data.Prompt
|
||||
job.ImgURL = imgURL
|
||||
job.Progress = data.Progress
|
||||
job.Hash = data.Image.Hash
|
||||
job.CreatedAt = time.Now()
|
||||
res = tx.Create(&job)
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
if wsClient == nil { // 客户端断线,则丢弃
|
||||
logger.Errorf("Client is offline: %+v", data)
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
if data.Status == Finished {
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||
// delete client
|
||||
// 本次绘画完毕,移除客户端
|
||||
h.App.MjTaskClients.Delete(task.Id)
|
||||
} else {
|
||||
data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
||||
// 使用代理临时转发图片
|
||||
if data.Image.URL != "" {
|
||||
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Proxy 通过代理访问 discord f服务器图片
|
||||
func (h *MidJourneyHandler) Proxy(c *gin.Context) {
|
||||
logger.Info(c.Request.Host, c.Request.Proto)
|
||||
return
|
||||
url := c.Query("url")
|
||||
image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
|
||||
imgURL := c.Query("url")
|
||||
imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL)
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, err.Error())
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image))
|
||||
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||
}
|
||||
|
||||
type reqVo struct {
|
||||
|
||||
Reference in New Issue
Block a user