mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 09:16:39 +08:00
fix bug: remove timeout task ONLY for unfinished(progress < 100)
This commit is contained in:
parent
5230f90540
commit
cfe333e89f
@ -146,7 +146,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if data.SRef != "" {
|
if data.SRef != "" {
|
||||||
params += fmt.Sprintf(" --sref %s", data.CRef)
|
params += fmt.Sprintf(" --sref %s", data.SRef)
|
||||||
}
|
}
|
||||||
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
|
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
|
||||||
params += fmt.Sprintf(" %s", data.Model)
|
params += fmt.Sprintf(" %s", data.Model)
|
||||||
|
@ -36,7 +36,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
}
|
}
|
||||||
cli := NewPlusClient(config)
|
cli := NewPlusClient(config)
|
||||||
name := fmt.Sprintf("mj-plus-service-%d", k)
|
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
service := NewService(name, taskQueue, notifyQueue, db, cli)
|
||||||
go func() {
|
go func() {
|
||||||
service.Run()
|
service.Run()
|
||||||
}()
|
}()
|
||||||
@ -49,7 +49,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
}
|
}
|
||||||
cli := NewProxyClient(config)
|
cli := NewProxyClient(config)
|
||||||
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
service := NewService(name, taskQueue, notifyQueue, db, cli)
|
||||||
go func() {
|
go func() {
|
||||||
service.Run()
|
service.Run()
|
||||||
}()
|
}()
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -16,41 +15,26 @@ import (
|
|||||||
|
|
||||||
// Service MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Name string // service Name
|
Name string // service Name
|
||||||
Client Client // MJ Client
|
Client Client // MJ Client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
|
||||||
HandledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
|
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
Name: name,
|
Name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
Client: cli,
|
Client: cli,
|
||||||
taskTimeout: timeout,
|
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||||
for {
|
for {
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
var task types.MjTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -125,9 +109,6 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Infof("任务提交成功:%+v", res)
|
logger.Infof("任务提交成功:%+v", res)
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
|
||||||
// 更新任务 ID/频道
|
// 更新任务 ID/频道
|
||||||
job.TaskId = res.Result
|
job.TaskId = res.Result
|
||||||
job.MessageId = res.Result
|
job.MessageId = res.Result
|
||||||
@ -136,27 +117,6 @@ func (s *Service) Run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
|
||||||
func (s *Service) canHandleTask() bool {
|
|
||||||
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
|
||||||
return handledNum < s.maxHandleTaskNum
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the timeout tasks
|
|
||||||
func (s *Service) checkTasks() {
|
|
||||||
for k, t := range s.taskStartTimes {
|
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
|
||||||
delete(s.taskStartTimes, k)
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
|
|
||||||
s.db.Model(&model.MidJourneyJob{Id: uint(k)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": "任务超时",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
type CBReq struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Action string `json:"action"`
|
Action string `json:"action"`
|
||||||
@ -187,6 +147,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
|||||||
"progress": -1,
|
"progress": -1,
|
||||||
"err_msg": task.FailReason,
|
"err_msg": task.FailReason,
|
||||||
})
|
})
|
||||||
|
s.notifyQueue.RPush(job.UserId)
|
||||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,10 +164,6 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
|
|||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||||
}
|
}
|
||||||
if task.Status == "SUCCESS" {
|
|
||||||
// release lock task
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
}
|
|
||||||
// 通知前端更新任务进度
|
// 通知前端更新任务进度
|
||||||
if oldProgress != job.Progress {
|
if oldProgress != job.Progress {
|
||||||
s.notifyQueue.RPush(job.UserId)
|
s.notifyQueue.RPush(job.UserId)
|
||||||
|
@ -146,6 +146,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
var errChan = make(chan error)
|
var errChan = make(chan error)
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||||
logger.Debugf("send image request to %s", apiURL)
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
|
// send a request to sd api endpoint
|
||||||
go func() {
|
go func() {
|
||||||
response, err := s.httpClient.R().
|
response, err := s.httpClient.R().
|
||||||
SetHeader("Authorization", s.config.ApiKey).
|
SetHeader("Authorization", s.config.ApiKey).
|
||||||
@ -179,12 +180,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
errChan <- nil
|
errChan <- nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// waiting for task finish
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case err := <-errChan: // 任务完成
|
case err := <-errChan:
|
||||||
if err != nil {
|
if err != nil { // task failed
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": -1,
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
})
|
||||||
|
s.notifyQueue.RPush(task.UserId)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// task finished
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
s.notifyQueue.RPush(task.UserId)
|
s.notifyQueue.RPush(task.UserId)
|
||||||
// 从 leveldb 中删除预览图片数据
|
// 从 leveldb 中删除预览图片数据
|
||||||
|
Loading…
Reference in New Issue
Block a user