mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-07 00:36:08 +08:00
重构异步任务更新方式,使用 Http 替代 websocket
This commit is contained in:
@@ -16,9 +16,10 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -30,7 +31,6 @@ var logger = logger2.GetLogger()
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
wsService *service.WebsocketService
|
||||
@@ -41,7 +41,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
|
||||
return &Service{
|
||||
httpClient: req.C(),
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
@@ -102,8 +101,6 @@ func (s *Service) Run() {
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -225,15 +222,12 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
resp, err := s.checkTaskProgress(apiKey)
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -242,7 +236,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, error) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().
|
||||
@@ -250,13 +244,13 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressRe
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return nil, err
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||
return nil, fmt.Errorf("error http code status: %v", response.Status)
|
||||
}
|
||||
|
||||
return nil, &res
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
@@ -264,25 +258,6 @@ func (s *Service) PushTask(task types.SdTask) {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
@@ -297,7 +272,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 5 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*5 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
|
||||
Reference in New Issue
Block a user