mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 08:46:38 +08:00
feat: optimize mj notidy api, use job queue to send ai drawing request
This commit is contained in:
parent
1d3acc8ed3
commit
2e13ddf405
@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
err := h.notifyHandler(c, data)
|
||||
err, finished := h.notifyHandler(c, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解除任务锁定
|
||||
if data.Status == Finished || data.Status == Stopped {
|
||||
if finished && (data.Status == Finished || data.Status == Stopped) {
|
||||
h.redis.Del(c, service.MjRunningJobKey)
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error {
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
|
||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
var task service.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if task.Src == service.TaskSrcImg { // 绘画任务
|
||||
@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
res := h.db.First(&job, task.Id)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", err)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
|
||||
// download image
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return err
|
||||
return err, false
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
} else {
|
||||
// 使用图片代理
|
||||
job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||
job.ImgURL = data.Image.URL
|
||||
}
|
||||
res = h.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", err.Error())
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
|
||||
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
||||
@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||
if res.Error == nil {
|
||||
logger.Warn("重复消息:", data.MessageId)
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
wsClient := h.App.MjTaskClients.Get(task.Id)
|
||||
@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
return err
|
||||
return err, false
|
||||
}
|
||||
|
||||
tx := h.db.Begin()
|
||||
@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
}
|
||||
res = tx.Create(&message)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
|
||||
// save the job
|
||||
@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
||||
job.Prompt = data.Prompt
|
||||
job.ImgURL = imgURL
|
||||
job.Progress = data.Progress
|
||||
job.Hash = data.Image.Hash
|
||||
job.CreatedAt = time.Now()
|
||||
res = tx.Create(&job)
|
||||
if res.Error != nil {
|
||||
tx.Rollback()
|
||||
return res.Error
|
||||
return res.Error, false
|
||||
}
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
if wsClient == nil { // 客户端断线,则丢弃
|
||||
logger.Errorf("Client is offline: %+v", data)
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
if data.Status == Finished {
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||
// delete client
|
||||
// 本次绘画完毕,移除客户端
|
||||
h.App.MjTaskClients.Delete(task.Id)
|
||||
} else {
|
||||
data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
||||
// 使用代理临时转发图片
|
||||
if data.Image.URL != "" {
|
||||
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// Proxy 通过代理访问 discord f服务器图片
|
||||
func (h *MidJourneyHandler) Proxy(c *gin.Context) {
|
||||
logger.Info(c.Request.Host, c.Request.Proto)
|
||||
return
|
||||
url := c.Query("url")
|
||||
image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
|
||||
imgURL := c.Query("url")
|
||||
imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL)
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, err.Error())
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image))
|
||||
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||
}
|
||||
|
||||
type reqVo struct {
|
||||
|
@ -68,9 +68,8 @@ func (s *MjService) Run() {
|
||||
logger.Info("Starting MidJourney job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
t, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
if err == nil {
|
||||
logger.Infof("An task is not finished: %s", t)
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
@ -107,11 +106,12 @@ func (s *MjService) Run() {
|
||||
task.RetryCount += 1
|
||||
s.taskQueue.RPush(task)
|
||||
// TODO: 执行失败通知聊天客户端
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 锁定任务执行通道,直到任务超时(10分钟)
|
||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600)
|
||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*10)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,11 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func main() {
|
||||
imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png"
|
||||
parse, _ := url.Parse("http://localhost:5678/static")
|
||||
|
||||
fmt.Println(filepath.Ext(filepath.Base(imageURL)))
|
||||
imgURLPrefix := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
||||
fmt.Println(imgURLPrefix)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user