optimize code for remove timeout and failed image drawing job

This commit is contained in:
RockYang 2024-04-21 21:44:28 +08:00
parent f9da18ad52
commit 47c5a0387b
6 changed files with 45 additions and 14 deletions

View File

@ -10,7 +10,6 @@ import (
"chatplus/utils"
"chatplus/utils/resp"
"net/http"
"time"
"github.com/gorilla/websocket"
@ -196,10 +195,6 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
var jobs = make([]vo.DallJob, 0)
for _, item := range items {
// delete failed or timeout tasks
if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
h.DB.Delete(&item)
}
var job vo.DallJob
err := utils.CopyObject(item, &job)
if err != nil {

View File

@ -159,6 +159,7 @@ func main() {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
service.CheckTaskStatus()
}),
// 邮件服务

View File

@ -257,3 +257,44 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
return imgURL, nil
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
s.db.Delete(&job)
var user model.User
s.db.Where("id = ?", job.UserId).First(&user)
// 退回绘图次数
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if res.Error == nil && res.RowsAffected > 0 {
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@ -163,7 +163,6 @@ func (p *ServicePool) SyncTaskProgress() {
for _, job := range items {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
// 删除任务
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
@ -190,7 +189,7 @@ func (p *ServicePool) SyncTaskProgress() {
}
}
time.Sleep(time.Second)
time.Sleep(time.Second * 10)
}
}()
}

View File

@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() {
continue
}
}
time.Sleep(time.Second * 10)
}
}()
}

View File

@ -184,12 +184,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
for {
select {
case err := <-errChan:
if err != nil { // task failed
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
if err != nil {
return err
}