feat: add asynchronously pull midjourney task progress in case the synchronization callback is fails

This commit is contained in:
RockYang
2024-01-12 18:24:28 +08:00
parent d70035ff0c
commit 9929746b1d
6 changed files with 89 additions and 33 deletions

View File

@@ -15,20 +15,20 @@ import (
// Service MJ 绘画服务
type Service struct {
name string // service name
Name string // service Name
Client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{
name: name,
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
@@ -40,7 +40,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name)
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
s.checkTasks()
if !s.canHandleTask() {
@@ -58,13 +58,14 @@ func (s *Service) Run() {
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name {
if task.ChannelId != "" && task.ChannelId != s.Name {
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
@@ -93,11 +94,11 @@ func (s *Service) Run() {
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
"task_id": res.Result,
"channel_id": s.Client.Config.Name,
"channel_id": s.Name,
})
}
@@ -105,7 +106,7 @@ func (s *Service) Run() {
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum
}
@@ -114,7 +115,7 @@ func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
atomic.AddInt32(&s.HandledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
@@ -156,7 +157,7 @@ func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
if data.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.handledTaskNum, -1)
atomic.AddInt32(&s.HandledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId)