mj websocket refactor is ready

This commit is contained in:
RockYang
2024-09-29 07:51:08 +08:00
parent 8fffa60569
commit 00a8bc6784
7 changed files with 63 additions and 149 deletions

View File

@@ -28,18 +28,20 @@ type Service struct {
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
wsService *service.WebsocketService
uploaderManager *oss.UploaderManager
clientIds map[uint]string
}
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service {
return &Service{
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
client: client,
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploaderManager: manager,
clientIds: map[uint]string{},
}
}
@@ -77,6 +79,7 @@ func (s *Service) Run() {
if task.Mode == "" {
task.Mode = "fast"
}
s.clientIds[task.Id] = task.ClientId
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
@@ -119,7 +122,7 @@ func (s *Service) Run() {
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
@@ -166,14 +169,11 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
cli := s.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChMj, message.Message)
}
}()
}
@@ -211,14 +211,11 @@ func (s *Service) DownloadImages() {
v.ImgURL = imgURL
s.db.Updates(&v)
cli := s.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(service.TaskStatusFinished))
if err != nil {
continue
}
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[v.Id],
UserId: v.UserId,
JobId: int(v.Id),
Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 5)
@@ -264,7 +261,11 @@ func (s *Service) SyncTaskProgress() {
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed})
continue
}
@@ -289,7 +290,11 @@ func (s *Service) SyncTaskProgress() {
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: message})
}
}