mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	optimize code for remove timeout and failed image drawing job
This commit is contained in:
		@@ -10,7 +10,6 @@ import (
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
@@ -196,10 +195,6 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.DallJob, 0)
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		// delete failed or timeout tasks
 | 
			
		||||
		if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
 | 
			
		||||
			h.DB.Delete(&item)
 | 
			
		||||
		}
 | 
			
		||||
		var job vo.DallJob
 | 
			
		||||
		err := utils.CopyObject(item, &job)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -159,6 +159,7 @@ func main() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
			service.CheckTaskNotify()
 | 
			
		||||
			service.DownloadImages()
 | 
			
		||||
			service.CheckTaskStatus()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 邮件服务
 | 
			
		||||
 
 | 
			
		||||
@@ -257,3 +257,44 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
 | 
			
		||||
	s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
 | 
			
		||||
	return imgURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务直接删除
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
 | 
			
		||||
					s.db.Delete(&job)
 | 
			
		||||
					var user model.User
 | 
			
		||||
					s.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
					// 退回绘图次数
 | 
			
		||||
					res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if res.Error == nil && res.RowsAffected > 0 {
 | 
			
		||||
						s.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "dall-e-3",
 | 
			
		||||
							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -163,7 +163,6 @@ func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
			for _, job := range items {
 | 
			
		||||
				// 失败或者 30 分钟还没完成的任务删除并退回算力
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
 | 
			
		||||
					// 删除任务
 | 
			
		||||
					p.db.Delete(&job)
 | 
			
		||||
					// 退回算力
 | 
			
		||||
					tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
@@ -190,7 +189,7 @@ func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -184,12 +184,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case err := <-errChan:
 | 
			
		||||
			if err != nil { // task failed
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": -1,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user