mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	fix bug: remove timeout task ONLY for unfinished(progress < 100)
This commit is contained in:
		@@ -146,7 +146,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.SRef != "" {
 | 
			
		||||
		params += fmt.Sprintf(" --sref %s", data.CRef)
 | 
			
		||||
		params += fmt.Sprintf(" --sref %s", data.SRef)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
 | 
			
		||||
		params += fmt.Sprintf(" %s", data.Model)
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
			
		||||
		}
 | 
			
		||||
		cli := NewPlusClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-plus-service-%d", k)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
@@ -49,7 +49,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
			
		||||
		}
 | 
			
		||||
		cli := NewProxyClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-proxy-service-%d", k)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,6 @@ import (
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -16,41 +15,26 @@ import (
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	Name             string // service Name
 | 
			
		||||
	Client           Client // MJ Client
 | 
			
		||||
	taskQueue        *store.RedisQueue
 | 
			
		||||
	notifyQueue      *store.RedisQueue
 | 
			
		||||
	db               *gorm.DB
 | 
			
		||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
			
		||||
	HandledTaskNum   int32             // already handled task number
 | 
			
		||||
	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
			
		||||
	taskTimeout      int64
 | 
			
		||||
	Name        string // service Name
 | 
			
		||||
	Client      Client // MJ Client
 | 
			
		||||
	taskQueue   *store.RedisQueue
 | 
			
		||||
	notifyQueue *store.RedisQueue
 | 
			
		||||
	db          *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		Name:             name,
 | 
			
		||||
		db:               db,
 | 
			
		||||
		taskQueue:        taskQueue,
 | 
			
		||||
		notifyQueue:      notifyQueue,
 | 
			
		||||
		Client:           cli,
 | 
			
		||||
		taskTimeout:      timeout,
 | 
			
		||||
		maxHandleTaskNum: maxTaskNum,
 | 
			
		||||
		taskStartTimes:   make(map[int]time.Time, 0),
 | 
			
		||||
		Name:        name,
 | 
			
		||||
		db:          db,
 | 
			
		||||
		taskQueue:   taskQueue,
 | 
			
		||||
		notifyQueue: notifyQueue,
 | 
			
		||||
		Client:      cli,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.Name)
 | 
			
		||||
	for {
 | 
			
		||||
		s.checkTasks()
 | 
			
		||||
		if !s.canHandleTask() {
 | 
			
		||||
			// current service is full, can not handle more task
 | 
			
		||||
			// waiting for running task finish
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -125,9 +109,6 @@ func (s *Service) Run() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
		// lock the task until the execute timeout
 | 
			
		||||
		s.taskStartTimes[int(task.Id)] = time.Now()
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, 1)
 | 
			
		||||
		// 更新任务 ID/频道
 | 
			
		||||
		job.TaskId = res.Result
 | 
			
		||||
		job.MessageId = res.Result
 | 
			
		||||
@@ -136,27 +117,6 @@ func (s *Service) Run() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check if current service instance can handle more task
 | 
			
		||||
func (s *Service) canHandleTask() bool {
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.HandledTaskNum)
 | 
			
		||||
	return handledNum < s.maxHandleTaskNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// remove the timeout tasks
 | 
			
		||||
func (s *Service) checkTasks() {
 | 
			
		||||
	for k, t := range s.taskStartTimes {
 | 
			
		||||
		if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
			
		||||
			delete(s.taskStartTimes, k)
 | 
			
		||||
			atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: uint(k)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"progress": -1,
 | 
			
		||||
				"err_msg":  "任务超时",
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	Id          string      `json:"id"`
 | 
			
		||||
	Action      string      `json:"action"`
 | 
			
		||||
@@ -187,6 +147,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  task.FailReason,
 | 
			
		||||
		})
 | 
			
		||||
		s.notifyQueue.RPush(job.UserId)
 | 
			
		||||
		return fmt.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -203,10 +164,6 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with update database: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
	if task.Status == "SUCCESS" {
 | 
			
		||||
		// release lock task
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
	}
 | 
			
		||||
	// 通知前端更新任务进度
 | 
			
		||||
	if oldProgress != job.Progress {
 | 
			
		||||
		s.notifyQueue.RPush(job.UserId)
 | 
			
		||||
 
 | 
			
		||||
@@ -146,6 +146,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
	var errChan = make(chan error)
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
 | 
			
		||||
	logger.Debugf("send image request to %s", apiURL)
 | 
			
		||||
	// send a request to sd api endpoint
 | 
			
		||||
	go func() {
 | 
			
		||||
		response, err := s.httpClient.R().
 | 
			
		||||
			SetHeader("Authorization", s.config.ApiKey).
 | 
			
		||||
@@ -179,12 +180,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
		errChan <- nil
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// waiting for task finish
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case err := <-errChan: // 任务完成
 | 
			
		||||
			if err != nil {
 | 
			
		||||
		case err := <-errChan:
 | 
			
		||||
			if err != nil { // task failed
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": -1,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// task finished
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			// 从 leveldb 中删除预览图片数据
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user