mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-19 01:36:38 +08:00
optimize code for remove timeout and failed image drawing job
This commit is contained in:
parent
f9da18ad52
commit
47c5a0387b
@ -10,7 +10,6 @@ import (
|
|||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
@ -196,10 +195,6 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
|
|||||||
|
|
||||||
var jobs = make([]vo.DallJob, 0)
|
var jobs = make([]vo.DallJob, 0)
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
// delete failed or timeout tasks
|
|
||||||
if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
|
|
||||||
h.DB.Delete(&item)
|
|
||||||
}
|
|
||||||
var job vo.DallJob
|
var job vo.DallJob
|
||||||
err := utils.CopyObject(item, &job)
|
err := utils.CopyObject(item, &job)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -159,6 +159,7 @@ func main() {
|
|||||||
service.Run()
|
service.Run()
|
||||||
service.CheckTaskNotify()
|
service.CheckTaskNotify()
|
||||||
service.DownloadImages()
|
service.DownloadImages()
|
||||||
|
service.CheckTaskStatus()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 邮件服务
|
// 邮件服务
|
||||||
|
@ -257,3 +257,44 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
|||||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
|
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
|
||||||
return imgURL, nil
|
return imgURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (s *Service) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.SdJob
|
||||||
|
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务直接删除
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
|
||||||
|
s.db.Delete(&job)
|
||||||
|
var user model.User
|
||||||
|
s.db.Where("id = ?", job.UserId).First(&user)
|
||||||
|
// 退回绘图次数
|
||||||
|
res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
|
if res.Error == nil && res.RowsAffected > 0 {
|
||||||
|
s.db.Create(&model.PowerLog{
|
||||||
|
UserId: user.Id,
|
||||||
|
Username: user.Username,
|
||||||
|
Type: types.PowerConsume,
|
||||||
|
Amount: job.Power,
|
||||||
|
Balance: user.Power + job.Power,
|
||||||
|
Mark: types.PowerAdd,
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
@ -163,7 +163,6 @@ func (p *ServicePool) SyncTaskProgress() {
|
|||||||
for _, job := range items {
|
for _, job := range items {
|
||||||
// 失败或者 30 分钟还没完成的任务删除并退回算力
|
// 失败或者 30 分钟还没完成的任务删除并退回算力
|
||||||
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
|
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
|
||||||
// 删除任务
|
|
||||||
p.db.Delete(&job)
|
p.db.Delete(&job)
|
||||||
// 退回算力
|
// 退回算力
|
||||||
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
|
||||||
@ -190,7 +189,7 @@ func (p *ServicePool) SyncTaskProgress() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -184,12 +184,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil { // task failed
|
if err != nil {
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": err.Error(),
|
|
||||||
})
|
|
||||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user