mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-07 08:46:04 +08:00
重构异步任务更新方式,使用 Http 替代 websocket
This commit is contained in:
@@ -15,10 +15,11 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -26,23 +27,19 @@ import (
|
||||
type Service struct {
|
||||
client *Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
wsService *service.WebsocketService
|
||||
uploaderManager *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||
client: client,
|
||||
wsService: wsService,
|
||||
uploaderManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
@@ -59,7 +56,6 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
task.Id = v.Id
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
s.PushTask(task)
|
||||
}
|
||||
|
||||
@@ -96,7 +92,6 @@ func (s *Service) Run() {
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
}
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
@@ -139,7 +134,6 @@ func (s *Service) Run() {
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
@@ -178,24 +172,6 @@ func GetImageHash(action string) string {
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.MidJourneyJob
|
||||
@@ -228,12 +204,6 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
v.ImgURL = imgURL
|
||||
s.db.Updates(&v)
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[v.Id],
|
||||
UserId: v.UserId,
|
||||
JobId: int(v.Id),
|
||||
Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
@@ -259,7 +229,7 @@ func (s *Service) SyncTaskProgress() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 10 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
@@ -279,18 +249,12 @@ func (s *Service) SyncTaskProgress() {
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
logger.Errorf("task failed: %v", task.FailReason)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(task.Buttons) > 0 {
|
||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := job.Progress
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
if task.ImageUrl != "" {
|
||||
job.OrgURL = task.ImageUrl
|
||||
@@ -300,19 +264,6 @@ func (s *Service) SyncTaskProgress() {
|
||||
logger.Errorf("error with update database: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
message := service.TaskStatusRunning
|
||||
if job.Progress == 100 {
|
||||
message = service.TaskStatusFinished
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: message})
|
||||
}
|
||||
}
|
||||
|
||||
// 找出失败的任务,并恢复其扣减算力
|
||||
|
||||
Reference in New Issue
Block a user