optimize code for remove timeout and failed image drawing job

This commit is contained in:
RockYang 2024-04-21 21:44:28 +08:00
parent f9da18ad52
commit 47c5a0387b
6 changed files with 45 additions and 14 deletions

View File

@ -10,7 +10,6 @@ import (
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"net/http" "net/http"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -196,10 +195,6 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
var jobs = make([]vo.DallJob, 0) var jobs = make([]vo.DallJob, 0)
for _, item := range items { for _, item := range items {
// delete failed or timeout tasks
if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
h.DB.Delete(&item)
}
var job vo.DallJob var job vo.DallJob
err := utils.CopyObject(item, &job) err := utils.CopyObject(item, &job)
if err != nil { if err != nil {

View File

@ -159,6 +159,7 @@ func main() {
service.Run() service.Run()
service.CheckTaskNotify() service.CheckTaskNotify()
service.DownloadImages() service.DownloadImages()
service.CheckTaskStatus()
}), }),
// 邮件服务 // 邮件服务

View File

@ -257,3 +257,44 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed}) s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
return imgURL, nil return imgURL, nil
} }
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@ -163,7 +163,6 @@ func (p *ServicePool) SyncTaskProgress() {
for _, job := range items { for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力 // 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
// 删除任务
p.db.Delete(&job) p.db.Delete(&job)
// 退回算力 // 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
@ -190,7 +189,7 @@ func (p *ServicePool) SyncTaskProgress() {
} }
} }
time.Sleep(time.Second) time.Sleep(time.Second * 10)
} }
}() }()
} }

View File

@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() {
continue continue
} }
} }
time.Sleep(time.Second * 10)
} }
}() }()
} }

View File

@ -184,12 +184,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
for { for {
select { select {
case err := <-errChan: case err := <-errChan:
if err != nil { // task failed if err != nil {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
return err return err
} }