mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
feat: add asynchronously pull midjourney task progress in case the synchronization callback is fails
This commit is contained in:
@@ -6,9 +6,11 @@ import (
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -33,8 +35,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
config.ApiURL = "https://one-api.bltcy.top"
|
||||
client := plus.NewClient(config)
|
||||
name := fmt.Sprintf("MidJourney Plus Service-%d", k)
|
||||
name := fmt.Sprintf("mj-service-plus-%d", k)
|
||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
||||
go func() {
|
||||
servicePlus.Run()
|
||||
@@ -169,6 +172,10 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
|
||||
return fmt.Errorf("非法任务:%s", data.Id)
|
||||
}
|
||||
|
||||
// 任务已经拉取完成
|
||||
if job.Progress == 100 {
|
||||
return nil
|
||||
}
|
||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
||||
return servicePlus.Notify(data, job)
|
||||
}
|
||||
@@ -176,10 +183,68 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (p *ServicePool) SyncTaskProgress() {
|
||||
go func() {
|
||||
var items []model.MidJourneyJob
|
||||
for {
|
||||
res := p.db.Where("progress < ?", 100).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
// 30 分钟还没完成的任务直接删除
|
||||
if time.Now().Sub(v.CreatedAt) > time.Minute*30 {
|
||||
p.db.Delete(&v)
|
||||
// 退回绘图次数
|
||||
p.db.Model(&model.User{}).Where("id = ?", v.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(v.ChannelId, "mj-service-plus") {
|
||||
continue
|
||||
}
|
||||
|
||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
||||
task, err := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = getImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := v.Progress
|
||||
v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
v.Prompt = task.PromptEn
|
||||
if task.ImageUrl != "" {
|
||||
v.OrgURL = task.ImageUrl
|
||||
}
|
||||
v.UseProxy = true
|
||||
v.MessageId = task.Id
|
||||
|
||||
p.db.Updates(&v)
|
||||
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&servicePlus.HandledTaskNum, -1)
|
||||
}
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != v.Progress {
|
||||
p.notifyQueue.RPush(v.UserId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
||||
for _, s := range p.services {
|
||||
if servicePlus, ok := s.(*plus.Service); ok {
|
||||
if servicePlus.Client.Config.Name == name {
|
||||
if servicePlus.Name == name {
|
||||
return servicePlus
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user