fix bug: remove timeout task ONLY for unfinished(progress < 100)

This commit is contained in:
RockYang
2024-04-10 06:25:54 +08:00
parent 86b2d5eff2
commit 3e984a1bcb
4 changed files with 26 additions and 60 deletions

View File

@@ -8,7 +8,6 @@ import (
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
@@ -16,41 +15,26 @@ import (
// Service MJ 绘画服务
type Service struct {
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
HandledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
return &Service{
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0),
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
@@ -125,9 +109,6 @@ func (s *Service) Run() {
continue
}
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[int(task.Id)] = time.Now()
atomic.AddInt32(&s.HandledTaskNum, 1)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
@@ -136,27 +117,6 @@ func (s *Service) Run() {
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the timeout tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.HandledTaskNum, -1)
s.db.Model(&model.MidJourneyJob{Id: uint(k)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": "任务超时",
})
}
}
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
@@ -187,6 +147,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(job.UserId)
return fmt.Errorf("task failed: %v", task.FailReason)
}
@@ -203,10 +164,6 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)