feat: add asynchronously pull midjourney task progress in case the synchronization callback is fails

This commit is contained in:
RockYang 2024-01-12 18:24:28 +08:00
parent d70035ff0c
commit 9929746b1d
6 changed files with 89 additions and 33 deletions

View File

@ -73,8 +73,7 @@ type StableDiffusionConfig struct {
} }
type MidJourneyPlusConfig struct { type MidJourneyPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务 Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
Name string // 服务名称,保持唯一
ApiURL string ApiURL string
ApiKey string ApiKey string
NotifyURL string // 任务进度更新回调地址 NotifyURL string // 任务进度更新回调地址

View File

@ -328,24 +328,14 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
h.db.Delete(&model.MidJourneyJob{Id: job.Id}) h.db.Delete(&model.MidJourneyJob{Id: job.Id})
} }
if item.Progress < 100 { if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item)
// 退回绘图次数
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
// 正在运行中任务使用代理访问图片 // 正在运行中任务使用代理访问图片
if item.ImgURL == "" && item.OrgURL != "" { if h.App.Config.ImgCdnURL != "" {
if h.App.Config.ImgCdnURL != "" { job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL)
job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL) } else {
} else { image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) if err == nil {
if err == nil { job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} }
} }
} }

View File

@ -45,7 +45,7 @@ type resBody struct {
} }
func (h *TestHandler) Test(c *gin.Context) { func (h *TestHandler) Test(c *gin.Context) {
query(c) image(c)
} }

View File

@ -162,6 +162,7 @@ func main() {
if pool.HasAvailableService() { if pool.HasAvailableService() {
pool.DownloadImages() pool.DownloadImages()
pool.CheckTaskNotify() pool.CheckTaskNotify()
pool.SyncTaskProgress()
} }
}), }),

View File

@ -15,20 +15,20 @@ import (
// Service MJ 绘画服务 // Service MJ 绘画服务
type Service struct { type Service struct {
name string // service name Name string // service Name
Client *Client // MJ Client Client *Client // MJ Client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64 taskTimeout int64
} }
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{ return &Service{
name: name, Name: name,
db: db, db: db,
taskQueue: taskQueue, taskQueue: taskQueue,
notifyQueue: notifyQueue, notifyQueue: notifyQueue,
@ -40,7 +40,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name) logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for { for {
s.checkTasks() s.checkTasks()
if !s.canHandleTask() { if !s.canHandleTask() {
@ -58,13 +58,14 @@ func (s *Service) Run() {
} }
// if it's reference message, check if it's this channel's message // if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name { if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
} }
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes var res ImageRes
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
@ -93,11 +94,11 @@ func (s *Service) Run() {
logger.Infof("任务提交成功:%+v", res) logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout // lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now() s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1) atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道 // 更新任务 ID/频道
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{ s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
"task_id": res.Result, "task_id": res.Result,
"channel_id": s.Client.Config.Name, "channel_id": s.Name,
}) })
} }
@ -105,7 +106,7 @@ func (s *Service) Run() {
// check if current service instance can handle more task // check if current service instance can handle more task
func (s *Service) canHandleTask() bool { func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum) handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum return handledNum < s.maxHandleTaskNum
} }
@ -114,7 +115,7 @@ func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes { for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout { if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k) delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1) atomic.AddInt32(&s.HandledTaskNum, -1)
// delete task from database // delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
} }
@ -156,7 +157,7 @@ func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
if data.Status == "SUCCESS" { if data.Status == "SUCCESS" {
// release lock task // release lock task
atomic.AddInt32(&s.handledTaskNum, -1) atomic.AddInt32(&s.HandledTaskNum, -1)
} }
s.notifyQueue.RPush(job.UserId) s.notifyQueue.RPush(job.UserId)

View File

@ -6,9 +6,11 @@ import (
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils"
"fmt" "fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"strings" "strings"
"sync/atomic"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@ -33,8 +35,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
if config.Enabled == false { if config.Enabled == false {
continue continue
} }
config.ApiURL = "https://one-api.bltcy.top"
client := plus.NewClient(config) client := plus.NewClient(config)
name := fmt.Sprintf("MidJourney Plus Service-%d", k) name := fmt.Sprintf("mj-service-plus-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
go func() { go func() {
servicePlus.Run() servicePlus.Run()
@ -169,6 +172,10 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
return fmt.Errorf("非法任务:%s", data.Id) return fmt.Errorf("非法任务:%s", data.Id)
} }
// 任务已经拉取完成
if job.Progress == 100 {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(data, job) return servicePlus.Notify(data, job)
} }
@ -176,10 +183,68 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
return nil return nil
} }
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(v.CreatedAt) > time.Minute*30 {
p.db.Delete(&v)
// 退回绘图次数
p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
if !strings.HasPrefix(v.ChannelId, "mj-service-plus") {
continue
}
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, err := servicePlus.Client.QueryTask(v.TaskId)
if err != nil {
continue
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
}
oldProgress := v.Progress
v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
v.Prompt = task.PromptEn
if task.ImageUrl != "" {
v.OrgURL = task.ImageUrl
}
v.UseProxy = true
v.MessageId = task.Id
p.db.Updates(&v)
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&servicePlus.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != v.Progress {
p.notifyQueue.RPush(v.UserId)
}
}
}
time.Sleep(time.Second)
}
}()
}
func (p *ServicePool) getServicePlus(name string) *plus.Service { func (p *ServicePool) getServicePlus(name string) *plus.Service {
for _, s := range p.services { for _, s := range p.services {
if servicePlus, ok := s.(*plus.Service); ok { if servicePlus, ok := s.(*plus.Service); ok {
if servicePlus.Client.Config.Name == name { if servicePlus.Name == name {
return servicePlus return servicePlus
} }
} }