diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 7988a380..fa8762c9 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -146,7 +146,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { } if data.SRef != "" { - params += fmt.Sprintf(" --sref %s", data.CRef) + params += fmt.Sprintf(" --sref %s", data.SRef) } if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") { params += fmt.Sprintf(" %s", data.Model) diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 0143467f..7404021e 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -36,7 +36,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa } cli := NewPlusClient(config) name := fmt.Sprintf("mj-plus-service-%d", k) - service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) + service := NewService(name, taskQueue, notifyQueue, db, cli) go func() { service.Run() }() @@ -49,7 +49,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa } cli := NewProxyClient(config) name := fmt.Sprintf("mj-proxy-service-%d", k) - service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) + service := NewService(name, taskQueue, notifyQueue, db, cli) go func() { service.Run() }() diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 30f8f265..ad118308 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -8,7 +8,6 @@ import ( "chatplus/utils" "fmt" "strings" - "sync/atomic" "time" "gorm.io/gorm" @@ -16,41 +15,26 @@ import ( // Service MJ 绘画服务 type Service struct { - Name string // service Name - Client Client // MJ Client - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - maxHandleTaskNum int32 // max task number current service can handle - HandledTaskNum int32 // already handled task number - taskStartTimes map[int]time.Time // task start time, to check if the task is timeout - taskTimeout int64 + Name string // service Name + Client Client // MJ Client + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB } -func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service { +func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { return &Service{ - Name: name, - db: db, - taskQueue: taskQueue, - notifyQueue: notifyQueue, - Client: cli, - taskTimeout: timeout, - maxHandleTaskNum: maxTaskNum, - taskStartTimes: make(map[int]time.Time, 0), + Name: name, + db: db, + taskQueue: taskQueue, + notifyQueue: notifyQueue, + Client: cli, } } func (s *Service) Run() { logger.Infof("Starting MidJourney job consumer for %s", s.Name) for { - s.checkTasks() - if !s.canHandleTask() { - // current service is full, can not handle more task - // waiting for running task finish - time.Sleep(time.Second * 3) - continue - } - var task types.MjTask err := s.taskQueue.LPop(&task) if err != nil { @@ -125,9 +109,6 @@ func (s *Service) Run() { continue } logger.Infof("任务提交成功:%+v", res) - // lock the task until the execute timeout - s.taskStartTimes[int(task.Id)] = time.Now() - atomic.AddInt32(&s.HandledTaskNum, 1) // 更新任务 ID/频道 job.TaskId = res.Result job.MessageId = res.Result @@ -136,27 +117,6 @@ func (s *Service) Run() { } } -// check if current service instance can handle more task -func (s *Service) canHandleTask() bool { - handledNum := atomic.LoadInt32(&s.HandledTaskNum) - return handledNum < s.maxHandleTaskNum -} - -// remove the timeout tasks -func (s *Service) checkTasks() { - for k, t := range s.taskStartTimes { - if time.Now().Unix()-t.Unix() > s.taskTimeout { - delete(s.taskStartTimes, k) - atomic.AddInt32(&s.HandledTaskNum, -1) - - s.db.Model(&model.MidJourneyJob{Id: uint(k)}).UpdateColumns(map[string]interface{}{ - "progress": -1, - "err_msg": "任务超时", - }) - } - } -} - type CBReq struct { Id string `json:"id"` Action string `json:"action"` @@ -187,6 +147,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error { "progress": -1, "err_msg": task.FailReason, }) + s.notifyQueue.RPush(job.UserId) return fmt.Errorf("task failed: %v", task.FailReason) } @@ -203,10 +164,6 @@ func (s *Service) Notify(job model.MidJourneyJob) error { if tx.Error != nil { return fmt.Errorf("error with update database: %v", tx.Error) } - if task.Status == "SUCCESS" { - // release lock task - atomic.AddInt32(&s.HandledTaskNum, -1) - } // 通知前端更新任务进度 if oldProgress != job.Progress { s.notifyQueue.RPush(job.UserId) diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 34d47697..f6d3b081 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -146,6 +146,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { var errChan = make(chan error) apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) logger.Debugf("send image request to %s", apiURL) + // send a request to sd api endpoint go func() { response, err := s.httpClient.R(). SetHeader("Authorization", s.config.ApiKey). @@ -179,12 +180,20 @@ func (s *Service) Txt2Img(task types.SdTask) error { errChan <- nil }() + // waiting for task finish for { select { - case err := <-errChan: // 任务完成 - if err != nil { + case err := <-errChan: + if err != nil { // task failed + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": err.Error(), + }) + s.notifyQueue.RPush(task.UserId) return err } + + // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.notifyQueue.RPush(task.UserId) // 从 leveldb 中删除预览图片数据