diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index 5a738a7e..9401d610 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -10,7 +10,6 @@ import ( "chatplus/utils" "chatplus/utils/resp" "net/http" - "time" "github.com/gorilla/websocket" @@ -196,10 +195,6 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in var jobs = make([]vo.DallJob, 0) for _, item := range items { - // delete failed or timeout tasks - if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 { - h.DB.Delete(&item) - } var job vo.DallJob err := utils.CopyObject(item, &job) if err != nil { diff --git a/api/main.go b/api/main.go index 56bf7caf..7fd23eea 100644 --- a/api/main.go +++ b/api/main.go @@ -159,6 +159,7 @@ func main() { service.Run() service.CheckTaskNotify() service.DownloadImages() + service.CheckTaskStatus() }), // 邮件服务 diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 8e8c8c2f..f3929e84 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -257,3 +257,44 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed}) return imgURL, nil } + +// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 +func (s *Service) CheckTaskStatus() { + go func() { + logger.Info("Running Stable-Diffusion task status checking ...") + for { + var jobs []model.SdJob + res := s.db.Where("progress < ?", 100).Find(&jobs) + if res.Error != nil { + time.Sleep(5 * time.Second) + continue + } + + for _, job := range jobs { + // 5 分钟还没完成的任务直接删除 + if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { + s.db.Delete(&job) + var user model.User + s.db.Where("id = ?", job.UserId).First(&user) + // 退回绘图次数 + res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) + if res.Error == nil && res.RowsAffected > 0 { + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power + job.Power, + Mark: types.PowerAdd, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + continue + } + } + time.Sleep(time.Second * 10) + } + }() +} diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 5fcb681d..01bb9627 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -163,7 +163,6 @@ func (p *ServicePool) SyncTaskProgress() { for _, job := range items { // 失败或者 30 分钟还没完成的任务删除并退回算力 if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { - // 删除任务 p.db.Delete(&job) // 退回算力 tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) @@ -190,7 +189,7 @@ func (p *ServicePool) SyncTaskProgress() { } } - time.Sleep(time.Second) + time.Sleep(time.Second * 10) } }() } diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index 50c39a6b..e191eef8 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() { continue } } - + time.Sleep(time.Second * 10) } }() } diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 7ab82c95..9d6932a2 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -184,12 +184,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { for { select { case err := <-errChan: - if err != nil { // task failed - s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ - "progress": -1, - "err_msg": err.Error(), - }) - s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) + if err != nil { return err }