mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 09:16:39 +08:00
fix bug: remove timeout task ONLY for unfinished(progress < 100)
This commit is contained in:
parent
5230f90540
commit
cfe333e89f
@ -146,7 +146,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
if data.SRef != "" {
|
||||
params += fmt.Sprintf(" --sref %s", data.CRef)
|
||||
params += fmt.Sprintf(" --sref %s", data.SRef)
|
||||
}
|
||||
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
|
||||
params += fmt.Sprintf(" %s", data.Model)
|
||||
|
@ -36,7 +36,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
}
|
||||
cli := NewPlusClient(config)
|
||||
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
service := NewService(name, taskQueue, notifyQueue, db, cli)
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
@ -49,7 +49,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
}
|
||||
cli := NewProxyClient(config)
|
||||
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
service := NewService(name, taskQueue, notifyQueue, db, cli)
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -16,41 +15,26 @@ import (
|
||||
|
||||
// Service MJ 绘画服务
|
||||
type Service struct {
|
||||
Name string // service Name
|
||||
Client Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
HandledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
Name string // service Name
|
||||
Client Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
||||
return &Service{
|
||||
Name: name,
|
||||
db: db,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
Client: cli,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
Name: name,
|
||||
db: db,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
Client: cli,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||
for {
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
var task types.MjTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
@ -125,9 +109,6 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||
// 更新任务 ID/频道
|
||||
job.TaskId = res.Result
|
||||
job.MessageId = res.Result
|
||||
@ -136,27 +117,6 @@ func (s *Service) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
// remove the timeout tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(k)}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": "任务超时",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
Id string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
@ -187,6 +147,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
"progress": -1,
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||
}
|
||||
|
||||
@ -203,10 +164,6 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||
}
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
}
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
|
@ -146,6 +146,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
var errChan = make(chan error)
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
// send a request to sd api endpoint
|
||||
go func() {
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
@ -179,12 +180,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
errChan <- nil
|
||||
}()
|
||||
|
||||
// waiting for task finish
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan: // 任务完成
|
||||
if err != nil {
|
||||
case err := <-errChan:
|
||||
if err != nil { // task failed
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
return err
|
||||
}
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
|
Loading…
Reference in New Issue
Block a user