From 8093a3eeb2216caff11e6f8e91371bad766d6439 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sun, 29 Sep 2024 07:51:08 +0800 Subject: [PATCH] mj websocket refactor is ready --- api/core/types/task.go | 5 +- api/handler/mj_handler.go | 34 ++----------- api/service/mj/service.go | 45 +++++++++-------- web/.env.development | 2 +- web/src/components/TaskList.vue | 30 +++-------- web/src/views/ChatPlus.vue | 7 --- web/src/views/ImageMj.vue | 89 +++++++++------------------------ 7 files changed, 63 insertions(+), 149 deletions(-) diff --git a/api/core/types/task.go b/api/core/types/task.go index 6ec62167..900fd52e 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -24,8 +24,9 @@ const ( // MjTask MidJourney 任务 type MjTask struct { - Id uint `json:"id"` - TaskId string `json:"task_id"` + Id uint `json:"id"` // 任务ID + TaskId string `json:"task_id"` // 中转任务ID + ClientId string `json:"client_id"` ImgArr []string `json:"img_arr"` Type TaskType `json:"type"` UserId int `json:"user_id"` diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index c758032b..8feaeb53 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -8,7 +8,6 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( - "encoding/base64" "fmt" "geekai/core" "geekai/core/types" @@ -67,6 +66,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { TaskType string `json:"task_type"` + ClientId string `json:"client_id"` Prompt string `json:"prompt"` NegPrompt string `json:"neg_prompt"` Rate string `json:"rate"` @@ -177,6 +177,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { h.mjService.PushTask(types.MjTask{ Id: job.Id, + ClientId: data.ClientId, TaskId: taskId, Type: types.TaskType(data.TaskType), Prompt: data.Prompt, @@ -187,11 +188,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.mjService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - // update user's power err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, @@ -208,6 +204,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { type reqVo struct { Index int `json:"index"` + ClientId string `json:"client_id"` ChannelId string `json:"channel_id"` MessageId string `json:"message_id"` MessageHash string `json:"message_hash"` @@ -244,6 +241,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { h.mjService.PushTask(types.MjTask{ Id: job.Id, + ClientId: data.ClientId, Type: types.TaskUpscale, UserId: userId, ChannelId: data.ChannelId, @@ -253,11 +251,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.mjService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - // update user's power err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, @@ -305,6 +298,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { h.mjService.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskVariation, + ClientId: data.ClientId, UserId: userId, Index: data.Index, ChannelId: data.ChannelId, @@ -313,11 +307,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.mjService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, Model: "mid-journey", @@ -397,14 +386,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize if err != nil { continue } - - if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { - image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) - if err == nil { - job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - jobs = append(jobs, job) } return nil, vo.NewPage(total, page, pageSize, jobs) @@ -449,11 +430,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } - client := h.mjService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - resp.SUCCESS(c) } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 4c9059fa..8086318a 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -28,18 +28,20 @@ type Service struct { taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB - Clients *types.LMap[uint, *types.WsClient] // UserId => Client + wsService *service.WebsocketService uploaderManager *oss.UploaderManager + clientIds map[uint]string } -func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service { +func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service { return &Service{ db: db, taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli), client: client, - Clients: types.NewLMap[uint, *types.WsClient](), + wsService: wsService, uploaderManager: manager, + clientIds: map[uint]string{}, } } @@ -77,6 +79,7 @@ func (s *Service) Run() { if task.Mode == "" { task.Mode = "fast" } + s.clientIds[task.Id] = task.ClientId var job model.MidJourneyJob tx := s.db.Where("id = ?", task.Id).First(&job) @@ -119,7 +122,7 @@ func (s *Service) Run() { // update the task progress s.db.Updates(&job) // 任务失败,通知前端 - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) continue } logger.Infof("任务提交成功:%+v", res) @@ -166,14 +169,11 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } - cli := s.Clients.Get(uint(message.UserId)) - if cli == nil { - continue - } - err = cli.Send([]byte(message.Message)) - if err != nil { + client := s.wsService.Clients.Get(message.ClientId) + if client == nil { continue } + utils.SendChannelMsg(client, types.ChMj, message.Message) } }() } @@ -211,14 +211,11 @@ func (s *Service) DownloadImages() { v.ImgURL = imgURL s.db.Updates(&v) - cli := s.Clients.Get(uint(v.UserId)) - if cli == nil { - continue - } - err = cli.Send([]byte(service.TaskStatusFinished)) - if err != nil { - continue - } + s.notifyQueue.RPush(service.NotifyMessage{ + ClientId: s.clientIds[v.Id], + UserId: v.UserId, + JobId: int(v.Id), + Message: service.TaskStatusFinished}) } time.Sleep(time.Second * 5) @@ -264,7 +261,11 @@ func (s *Service) SyncTaskProgress() { "err_msg": task.FailReason, }) logger.Errorf("task failed: %v", task.FailReason) - s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ + ClientId: s.clientIds[job.Id], + UserId: job.UserId, + JobId: int(job.Id), + Message: service.TaskStatusFailed}) continue } @@ -289,7 +290,11 @@ func (s *Service) SyncTaskProgress() { if job.Progress == 100 { message = service.TaskStatusFinished } - s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) + s.notifyQueue.RPush(service.NotifyMessage{ + ClientId: s.clientIds[job.Id], + UserId: job.UserId, + JobId: int(job.Id), + Message: message}) } } diff --git a/web/.env.development b/web/.env.development index 97c67f7e..99693506 100644 --- a/web/.env.development +++ b/web/.env.development @@ -6,6 +6,6 @@ VUE_APP_ADMIN_USER=admin VUE_APP_ADMIN_PASS=admin123 VUE_APP_KEY_PREFIX=GeekAI_DEV_ VUE_APP_TITLE="Geek-AI 创作系统" -VUE_APP_VERSION=v4.1.4 +VUE_APP_VERSION=v4.1.5 VUE_APP_DOCS_URL=https://docs.geekai.me VUE_APP_GIT_URL=https://github.com/yangjian102621/geekai diff --git a/web/src/components/TaskList.vue b/web/src/components/TaskList.vue index d2e63e0d..86ced83a 100644 --- a/web/src/components/TaskList.vue +++ b/web/src/components/TaskList.vue @@ -1,28 +1,14 @@