From 8fffa60569de0e116d52b1c930a06c48fa334e45 Mon Sep 17 00:00:00 2001 From: RockYang Date: Fri, 27 Sep 2024 18:28:54 +0800 Subject: [PATCH] sd websocket refactor is finished --- api/core/types/task.go | 2 + api/handler/sd_handler.go | 14 +++---- api/service/sd/service.go | 19 ++++------ api/service/types.go | 7 ++-- web/src/App.vue | 10 ++--- web/src/views/ImageSd.vue | 67 ++++++++++------------------------ web/src/views/admin/ApiKey.vue | 2 +- 7 files changed, 45 insertions(+), 76 deletions(-) diff --git a/api/core/types/task.go b/api/core/types/task.go index d41e592a..6ec62167 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -43,12 +43,14 @@ type MjTask struct { type SdTask struct { Id int `json:"id"` // job 数据库ID Type TaskType `json:"type"` + ClientId string `json:"client_id"` UserId int `json:"user_id"` Params SdTaskParams `json:"params"` RetryCount int `json:"retry_count"` } type SdTaskParams struct { + ClientId string `json:"client_id"` // 客户端ID TaskId string `json:"task_id"` Prompt string `json:"prompt"` // 提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词 diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 2f658e73..0d2ba6ea 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -144,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) { } h.sdService.PushTask(types.SdTask{ - Id: int(job.Id), - Type: types.TaskImage, - Params: params, - UserId: userId, + Id: int(job.Id), + ClientId: data.ClientId, + Type: types.TaskImage, + Params: params, + UserId: userId, }) - client := h.sdService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - // update user's power err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, diff --git a/api/service/sd/service.go b/api/service/sd/service.go index a6c9a856..9aa25c2a 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -34,17 +34,17 @@ type Service struct { db *gorm.DB uploadManager *oss.UploaderManager leveldb *store.LevelDB - Clients *types.LMap[uint, *types.WsClient] // UserId => Client + wsService *service.WebsocketService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service { return &Service{ httpClient: req.C(), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), db: db, leveldb: levelDB, - Clients: types.NewLMap[uint, *types.WsClient](), + wsService: wsService, uploadManager: manager, } } @@ -90,7 +90,7 @@ func (s *Service) Run() { "err_msg": err.Error(), }) // 通知前端,任务失败 - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) continue } } @@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) // 从 leveldb 中删除预览图片数据 _ = s.leveldb.Delete(task.Params.TaskId) return nil @@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { if err == nil && resp.Progress > 0 { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) // 发送更新状态信号 - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) // 保存预览图片数据 if resp.CurrentImage != "" { _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) @@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } - client := s.Clients.Get(uint(message.UserId)) + client := s.wsService.Clients.Get(message.ClientId) if client == nil { continue } - err = client.Send([]byte(message.Message)) - if err != nil { - continue - } + utils.SendChannelMsg(client, types.ChSd, message.Message) } }() } diff --git a/api/service/types.go b/api/service/types.go index 70f8eb92..1c5c601e 100644 --- a/api/service/types.go +++ b/api/service/types.go @@ -8,9 +8,10 @@ const ( ) type NotifyMessage struct { - UserId int `json:"user_id"` - JobId int `json:"job_id"` - Message string `json:"message"` + UserId int `json:"user_id"` + ClientId string `json:"client_id"` + JobId int `json:"job_id"` + Message string `json:"message"` } const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" diff --git a/web/src/App.vue b/web/src/App.vue index d999a96e..bfb50e0c 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -71,12 +71,12 @@ const connect = () => { handler.value = setInterval(() => { _socket.send(JSON.stringify({"type":"ping"})) },5000) + }) - for (const key in store.messageHandlers) { - console.log(key, store.messageHandlers[key]) - store.setMessageHandler(store.messageHandlers[key]) - } - }); + for (const key in store.messageHandlers) { + console.log(key, store.messageHandlers[key]) + store.setMessageHandler(store.messageHandlers[key]) + } _socket.addEventListener('close', () => { store.setSocket(null) diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 2b377022..a9309c93 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -487,11 +487,11 @@