diff --git a/api/core/types/config.go b/api/core/types/config.go index c8ade652..78b708d4 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -73,8 +73,7 @@ type StableDiffusionConfig struct { } type MidJourneyPlusConfig struct { - Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 - Name string // 服务名称,保持唯一 + Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 ApiURL string ApiKey string NotifyURL string // 任务进度更新回调地址 diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index b49a08c9..d1e3d13b 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -328,24 +328,14 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { h.db.Delete(&model.MidJourneyJob{Id: job.Id}) } - if item.Progress < 100 { - // 10 分钟还没完成的任务直接删除 - if time.Now().Sub(item.CreatedAt) > time.Minute*10 { - h.db.Delete(&item) - // 退回绘图次数 - h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) - continue - } - + if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { // 正在运行中任务使用代理访问图片 - if item.ImgURL == "" && item.OrgURL != "" { - if h.App.Config.ImgCdnURL != "" { - job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL) - } else { - image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) - if err == nil { - job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } + if h.App.Config.ImgCdnURL != "" { + job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL) + } else { + image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) } } } diff --git a/api/handler/test_handler.go b/api/handler/test_handler.go index 1c13f11f..578f27a9 100644 --- a/api/handler/test_handler.go +++ b/api/handler/test_handler.go @@ -45,7 +45,7 @@ type resBody struct { } func (h *TestHandler) Test(c *gin.Context) { - query(c) + image(c) } diff --git a/api/main.go b/api/main.go index a3cc6b83..e4984dbc 100644 --- a/api/main.go +++ b/api/main.go @@ -162,6 +162,7 @@ func main() { if pool.HasAvailableService() { pool.DownloadImages() pool.CheckTaskNotify() + pool.SyncTaskProgress() } }), diff --git a/api/service/mj/plus/service.go b/api/service/mj/plus/service.go index 788d4355..2a6a0435 100644 --- a/api/service/mj/plus/service.go +++ b/api/service/mj/plus/service.go @@ -15,20 +15,20 @@ import ( // Service MJ 绘画服务 type Service struct { - name string // service name + Name string // service Name Client *Client // MJ Client taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB maxHandleTaskNum int32 // max task number current service can handle - handledTaskNum int32 // already handled task number + HandledTaskNum int32 // already handled task number taskStartTimes map[int]time.Time // task start time, to check if the task is timeout taskTimeout int64 } func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { return &Service{ - name: name, + Name: name, db: db, taskQueue: taskQueue, notifyQueue: notifyQueue, @@ -40,7 +40,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red } func (s *Service) Run() { - logger.Infof("Starting MidJourney job consumer for %s", s.name) + logger.Infof("Starting MidJourney job consumer for %s", s.Name) for { s.checkTasks() if !s.canHandleTask() { @@ -58,13 +58,14 @@ func (s *Service) Run() { } // if it's reference message, check if it's this channel's message - if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name { + if task.ChannelId != "" && task.ChannelId != s.Name { + logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) s.taskQueue.RPush(task) time.Sleep(time.Second) continue } - logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) + logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) var res ImageRes switch task.Type { case types.TaskImage: @@ -93,11 +94,11 @@ func (s *Service) Run() { logger.Infof("任务提交成功:%+v", res) // lock the task until the execute timeout s.taskStartTimes[task.Id] = time.Now() - atomic.AddInt32(&s.handledTaskNum, 1) + atomic.AddInt32(&s.HandledTaskNum, 1) // 更新任务 ID/频道 s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{ "task_id": res.Result, - "channel_id": s.Client.Config.Name, + "channel_id": s.Name, }) } @@ -105,7 +106,7 @@ func (s *Service) Run() { // check if current service instance can handle more task func (s *Service) canHandleTask() bool { - handledNum := atomic.LoadInt32(&s.handledTaskNum) + handledNum := atomic.LoadInt32(&s.HandledTaskNum) return handledNum < s.maxHandleTaskNum } @@ -114,7 +115,7 @@ func (s *Service) checkTasks() { for k, t := range s.taskStartTimes { if time.Now().Unix()-t.Unix() > s.taskTimeout { delete(s.taskStartTimes, k) - atomic.AddInt32(&s.handledTaskNum, -1) + atomic.AddInt32(&s.HandledTaskNum, -1) // delete task from database s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") } @@ -156,7 +157,7 @@ func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error { if data.Status == "SUCCESS" { // release lock task - atomic.AddInt32(&s.handledTaskNum, -1) + atomic.AddInt32(&s.HandledTaskNum, -1) } s.notifyQueue.RPush(job.UserId) diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index b904640a..4861bb59 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -6,9 +6,11 @@ import ( "chatplus/service/oss" "chatplus/store" "chatplus/store/model" + "chatplus/utils" "fmt" "github.com/go-redis/redis/v8" "strings" + "sync/atomic" "time" "gorm.io/gorm" @@ -33,8 +35,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa if config.Enabled == false { continue } + config.ApiURL = "https://one-api.bltcy.top" client := plus.NewClient(config) - name := fmt.Sprintf("MidJourney Plus Service-%d", k) + name := fmt.Sprintf("mj-service-plus-%d", k) servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) go func() { servicePlus.Run() @@ -169,6 +172,10 @@ func (p *ServicePool) Notify(data plus.CBReq) error { return fmt.Errorf("非法任务:%s", data.Id) } + // 任务已经拉取完成 + if job.Progress == 100 { + return nil + } if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { return servicePlus.Notify(data, job) } @@ -176,10 +183,68 @@ func (p *ServicePool) Notify(data plus.CBReq) error { return nil } +// SyncTaskProgress 异步拉取任务 +func (p *ServicePool) SyncTaskProgress() { + go func() { + var items []model.MidJourneyJob + for { + res := p.db.Where("progress < ?", 100).Find(&items) + if res.Error != nil { + continue + } + + for _, v := range items { + // 30 分钟还没完成的任务直接删除 + if time.Now().Sub(v.CreatedAt) > time.Minute*30 { + p.db.Delete(&v) + // 退回绘图次数 + p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) + continue + } + + if !strings.HasPrefix(v.ChannelId, "mj-service-plus") { + continue + } + + if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil { + task, err := servicePlus.Client.QueryTask(v.TaskId) + if err != nil { + continue + } + if len(task.Buttons) > 0 { + v.Hash = getImageHash(task.Buttons[0].CustomId) + } + oldProgress := v.Progress + v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) + v.Prompt = task.PromptEn + if task.ImageUrl != "" { + v.OrgURL = task.ImageUrl + } + v.UseProxy = true + v.MessageId = task.Id + + p.db.Updates(&v) + + if task.Status == "SUCCESS" { + // release lock task + atomic.AddInt32(&servicePlus.HandledTaskNum, -1) + } + // 通知前端更新任务进度 + if oldProgress != v.Progress { + p.notifyQueue.RPush(v.UserId) + } + } + } + + time.Sleep(time.Second) + } + }() +} + func (p *ServicePool) getServicePlus(name string) *plus.Service { for _, s := range p.services { if servicePlus, ok := s.(*plus.Service); ok { - if servicePlus.Client.Config.Name == name { + if servicePlus.Name == name { return servicePlus } }