mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: optimize mj notidy api, use job queue to send ai drawing request
This commit is contained in:
parent
1d3acc8ed3
commit
2e13ddf405
@ -87,32 +87,32 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
h.lock.Lock()
|
h.lock.Lock()
|
||||||
defer h.lock.Unlock()
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
err := h.notifyHandler(c, data)
|
err, finished := h.notifyHandler(c, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解除任务锁定
|
// 解除任务锁定
|
||||||
if data.Status == Finished || data.Status == Stopped {
|
if finished && (data.Status == Finished || data.Status == Stopped) {
|
||||||
h.redis.Del(c, service.MjRunningJobKey)
|
h.redis.Del(c, service.MjRunningJobKey)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error {
|
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
|
||||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||||
if err != nil { // 过期任务,丢弃
|
if err != nil { // 过期任务,丢弃
|
||||||
logger.Warn("任务已过期:", err)
|
logger.Warn("任务已过期:", err)
|
||||||
return nil
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
var task service.MjTask
|
var task service.MjTask
|
||||||
err = utils.JsonDecode(taskString, &task)
|
err = utils.JsonDecode(taskString, &task)
|
||||||
if err != nil { // 非标准任务,丢弃
|
if err != nil { // 非标准任务,丢弃
|
||||||
logger.Warn("任务解析失败:", err)
|
logger.Warn("任务解析失败:", err)
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if task.Src == service.TaskSrcImg { // 绘画任务
|
if task.Src == service.TaskSrcImg { // 绘画任务
|
||||||
@ -121,29 +121,29 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
|||||||
res := h.db.First(&job, task.Id)
|
res := h.db.First(&job, task.Id)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Warn("非法任务:", err)
|
logger.Warn("非法任务:", err)
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
job.MessageId = data.MessageId
|
job.MessageId = data.MessageId
|
||||||
job.ReferenceId = data.ReferenceId
|
job.ReferenceId = data.ReferenceId
|
||||||
job.Progress = data.Progress
|
job.Progress = data.Progress
|
||||||
job.Prompt = data.Prompt
|
job.Prompt = data.Prompt
|
||||||
|
|
||||||
// download image
|
// 任务完成,将最终的图片下载下来
|
||||||
if data.Progress == 100 {
|
if data.Progress == 100 {
|
||||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with download img: ", err.Error())
|
logger.Error("error with download img: ", err.Error())
|
||||||
return err
|
return err, false
|
||||||
}
|
}
|
||||||
job.ImgURL = imgURL
|
job.ImgURL = imgURL
|
||||||
} else {
|
} else {
|
||||||
// 使用图片代理
|
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||||
job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
job.ImgURL = data.Image.URL
|
||||||
}
|
}
|
||||||
res = h.db.Updates(&job)
|
res = h.db.Updates(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("error with update job: ", err.Error())
|
logger.Error("error with update job: ", err.Error())
|
||||||
return res.Error
|
return res.Error, false
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
||||||
@ -151,7 +151,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
|||||||
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
|
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
logger.Warn("重复消息:", data.MessageId)
|
logger.Warn("重复消息:", data.MessageId)
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
wsClient := h.App.MjTaskClients.Get(task.Id)
|
wsClient := h.App.MjTaskClients.Get(task.Id)
|
||||||
@ -168,7 +168,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
|||||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||||
utils.ReplyMessage(wsClient, content)
|
utils.ReplyMessage(wsClient, content)
|
||||||
}
|
}
|
||||||
return err
|
return err, false
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := h.db.Begin()
|
tx := h.db.Begin()
|
||||||
@ -185,7 +185,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
|||||||
}
|
}
|
||||||
res = tx.Create(&message)
|
res = tx.Create(&message)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
return res.Error
|
return res.Error, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the job
|
// save the job
|
||||||
@ -195,44 +195,50 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error
|
|||||||
job.Prompt = data.Prompt
|
job.Prompt = data.Prompt
|
||||||
job.ImgURL = imgURL
|
job.ImgURL = imgURL
|
||||||
job.Progress = data.Progress
|
job.Progress = data.Progress
|
||||||
|
job.Hash = data.Image.Hash
|
||||||
job.CreatedAt = time.Now()
|
job.CreatedAt = time.Now()
|
||||||
res = tx.Create(&job)
|
res = tx.Create(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return res.Error
|
return res.Error, false
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
if wsClient == nil { // 客户端断线,则丢弃
|
if wsClient == nil { // 客户端断线,则丢弃
|
||||||
logger.Errorf("Client is offline: %+v", data)
|
logger.Errorf("Client is offline: %+v", data)
|
||||||
return nil
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.Status == Finished {
|
if data.Status == Finished {
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||||
// delete client
|
// 本次绘画完毕,移除客户端
|
||||||
h.App.MjTaskClients.Delete(task.Id)
|
h.App.MjTaskClients.Delete(task.Id)
|
||||||
} else {
|
} else {
|
||||||
data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
|
// 使用代理临时转发图片
|
||||||
|
if data.Image.URL != "" {
|
||||||
|
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
|
||||||
|
if err == nil {
|
||||||
|
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
|
}
|
||||||
|
}
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Proxy 通过代理访问 discord f服务器图片
|
||||||
func (h *MidJourneyHandler) Proxy(c *gin.Context) {
|
func (h *MidJourneyHandler) Proxy(c *gin.Context) {
|
||||||
logger.Info(c.Request.Host, c.Request.Proto)
|
imgURL := c.Query("url")
|
||||||
return
|
imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL)
|
||||||
url := c.Query("url")
|
|
||||||
image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusOK, err.Error())
|
c.String(http.StatusOK, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image))
|
c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
}
|
}
|
||||||
|
|
||||||
type reqVo struct {
|
type reqVo struct {
|
||||||
|
@ -68,9 +68,8 @@ func (s *MjService) Run() {
|
|||||||
logger.Info("Starting MidJourney job consumer.")
|
logger.Info("Starting MidJourney job consumer.")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
for {
|
for {
|
||||||
t, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Infof("An task is not finished: %s", t)
|
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -107,11 +106,12 @@ func (s *MjService) Run() {
|
|||||||
task.RetryCount += 1
|
task.RetryCount += 1
|
||||||
s.taskQueue.RPush(task)
|
s.taskQueue.RPush(task)
|
||||||
// TODO: 执行失败通知聊天客户端
|
// TODO: 执行失败通知聊天客户端
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 锁定任务执行通道,直到任务超时(10分钟)
|
// 锁定任务执行通道,直到任务超时(10分钟)
|
||||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600)
|
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*10)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,12 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png"
|
parse, _ := url.Parse("http://localhost:5678/static")
|
||||||
|
|
||||||
fmt.Println(filepath.Ext(filepath.Base(imageURL)))
|
imgURLPrefix := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)
|
||||||
|
fmt.Println(imgURLPrefix)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user