DO NOT refresh finished jobs when job is running

This commit is contained in:
RockYang
2024-04-20 21:30:55 +08:00
parent caa538a1d0
commit d02cb573fd
10 changed files with 816 additions and 64 deletions

View File

@@ -4,6 +4,7 @@ import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/store"
"chatplus/store/model"
"fmt"
@@ -69,16 +70,16 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(userId)
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte("Task Updated"))
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
@@ -127,7 +128,7 @@ func (p *ServicePool) DownloadImages() {
if cli == nil {
continue
}
err = cli.Send([]byte("Task Updated"))
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}

View File

@@ -3,6 +3,7 @@ package mj
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/sd"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
@@ -105,7 +106,7 @@ func (s *Service) Run() {
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
}
logger.Infof("任务提交成功:%+v", res)
@@ -147,7 +148,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
"progress": -1,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(job.UserId)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
@@ -166,7 +167,11 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
s.notifyQueue.RPush(job.UserId)
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}

View File

@@ -60,16 +60,16 @@ func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
var message NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(userId)
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = client.Send([]byte(message.Message))
if err != nil {
continue
}

View File

@@ -81,7 +81,7 @@ func (s *Service) Run() {
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
continue
}
}
@@ -189,13 +189,13 @@ func (s *Service) Txt2Img(task types.SdTask) error {
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
return err
}
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
@@ -205,7 +205,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(task.UserId)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)

View File

@@ -4,44 +4,14 @@ import logger2 "chatplus/logger"
var logger = logger2.GetLogger()
type TaskInfo struct {
UserId uint `json:"user_id"`
SessionId string `json:"session_id"`
JobId int `json:"job_id"`
TaskId string `json:"task_id"`
Data []interface{} `json:"data"`
EventData interface{} `json:"event_data"`
FnIndex int `json:"fn_index"`
SessionHash string `json:"session_hash"`
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
type CBReq struct {
UserId uint
SessionId string
JobId int
TaskId string
ImageName string
ImageData string
Progress int
Seed int64
Success bool
Message string
}
var ParamKeys = map[string]int{
"task_id": 0,
"prompt": 1,
"negative_prompt": 2,
"steps": 4,
"sampler": 5,
"face_fix": 7, // 面部修复
"cfg_scale": 8,
"seed": 27,
"height": 10,
"width": 9,
"hd_fix": 11,
"hd_redraw_rate": 12, //高清修复重绘幅度
"hd_scale": 13, // 高清修复放大倍数
"hd_scale_alg": 14, // 高清修复放大算法
"hd_sample_num": 15, // 高清修复采样次数
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)