From 1a1734abf09db9ea8dbc0984eae06684b84dc52c Mon Sep 17 00:00:00 2001 From: RockYang Date: Sun, 29 Sep 2024 19:28:47 +0800 Subject: [PATCH] websocket api refactor is ready --- CHANGELOG.md | 4 + api/core/types/task.go | 31 +++---- api/handler/dalle_handler.go | 20 ++--- api/handler/markmap_handler.go | 146 +++++++-------------------------- api/handler/sd_handler.go | 9 -- api/handler/suno_handler.go | 8 +- api/handler/video_handler.go | 19 ++--- api/handler/ws_handler.go | 2 +- api/service/dalle/service.go | 26 +++--- api/service/mj/service.go | 5 +- api/service/sd/service.go | 18 ++-- api/service/suno/service.go | 27 +++--- api/service/video/luma.go | 26 +++--- api/utils/openai.go | 21 +++-- database/update-v4.1.5.sql | 1 + web/src/views/Dalle.vue | 63 ++++---------- web/src/views/Luma.vue | 78 ++++-------------- web/src/views/MarkMap.vue | 111 ++++++------------------- web/src/views/Suno.vue | 59 ++++--------- 19 files changed, 210 insertions(+), 464 deletions(-) create mode 100644 database/update-v4.1.5.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c1bb9b1..88671e2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ # 更新日志 +## v4.1.5 +* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接 +* Bug修复:兼容手机端原生微信支付和支付宝支付渠道 +* Bug修复:修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题 ## v4.1.4 * 功能优化:用户文件列表组件增加分页功能支持 * Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码 diff --git a/api/core/types/task.go b/api/core/types/task.go index 900fd52e..5dac6443 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -72,18 +72,20 @@ type SdTaskParams struct { // DallTask DALL-E task type DallTask struct { - JobId uint `json:"job_id"` - UserId uint `json:"user_id"` - Prompt string `json:"prompt"` - N int `json:"n"` - Quality string `json:"quality"` - Size string `json:"size"` - Style string `json:"style"` + ClientId string `json:"client_id"` + JobId uint `json:"job_id"` + UserId uint `json:"user_id"` + Prompt string `json:"prompt"` + N int `json:"n"` + Quality string `json:"quality"` + Size string `json:"size"` + Style string `json:"style"` Power int `json:"power"` } type SunoTask struct { + ClientId string `json:"client_id"` Id uint `json:"id"` Channel string `json:"channel"` UserId int `json:"user_id"` @@ -107,13 +109,14 @@ const ( ) type VideoTask struct { - Id uint `json:"id"` - Channel string `json:"channel"` - UserId int `json:"user_id"` - Type string `json:"type"` - TaskId string `json:"task_id"` - Prompt string `json:"prompt"` // 提示词 - Params VideoParams `json:"params"` + ClientId string `json:"client_id"` + Id uint `json:"id"` + Channel string `json:"channel"` + UserId int `json:"user_id"` + Type string `json:"type"` + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + Params VideoParams `json:"params"` } type VideoParams struct { diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index 816086f6..eb46710f 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) { } h.dallService.PushTask(types.DallTask{ - JobId: job.Id, - UserId: uint(userId), - Prompt: data.Prompt, - Quality: data.Quality, - Size: data.Size, - Style: data.Style, - Power: job.Power, + ClientId: data.ClientId, + JobId: job.Id, + UserId: uint(userId), + Prompt: data.Prompt, + Quality: data.Quality, + Size: data.Size, + Style: data.Style, + Power: job.Power, }) - - client := h.dallService.Clients.Get(job.UserId) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } resp.SUCCESS(c) } diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index 9337d996..e57f6f4d 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -8,23 +8,15 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( - "bufio" - "bytes" - "encoding/json" - "errors" "fmt" "geekai/core" "geekai/core/types" "geekai/service" "geekai/store/model" "geekai/utils" + "geekai/utils/resp" "github.com/gin-gonic/gin" "gorm.io/gorm" - "io" - "net/http" - "net/url" - "strings" - "time" ) // MarkMapHandler 生成思维导图 @@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us // Generate 生成思维导图 func (h *MarkMapHandler) Generate(c *gin.Context) { + var data struct { + Prompt string `json:"prompt"` + ModelId int `json:"model_id"` + } -} + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } -func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { + userId := h.GetLoginUserId(c) var user model.User - res := h.DB.Model(&model.User{}).First(&user, userId) - if res.Error != nil { - return fmt.Errorf("error with query user info: %v", res.Error) + err := h.DB.Where("id", userId).First(&user, userId).Error + if err != nil { + resp.ERROR(c, "error with query user info") + return } var chatModel model.ChatModel - res = h.DB.Where("id", modelId).First(&chatModel) - if res.Error != nil { - return fmt.Errorf("error with query chat model: %v", res.Error) + err = h.DB.Where("id", data.ModelId).First(&chatModel).Error + if err != nil { + resp.ERROR(c, "error with query chat model") + return } if user.Power < chatModel.Power { - return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) + resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)) + return } messages := make([]interface{}, 0) @@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode ### 支付宝 ### 微信 -另外,除此之外不要任何解释性语句。 +请直接生成结果,不要任何解释性语句。 `}) - messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)}) - var req = types.ApiRequest{ - Model: chatModel.Value, - Stream: true, - Messages: messages, - } - - var apiKey model.ApiKey - response, err := h.doRequest(req, chatModel, &apiKey) + messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)}) + content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId) if err != nil { - return fmt.Errorf("请求 OpenAI API 失败: %s", err) - } - - defer response.Body.Close() - - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - // 循环读取 Chunk 消息 - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if !strings.Contains(line, "data:") || len(line) < 30 { - continue - } - - var responseBody = types.ApiResponse{} - err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil { // 数据解析出错 - return fmt.Errorf("error with decode data: %v", line) - } - - if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 - continue - } - - if responseBody.Choices[0].FinishReason == "stop" { - break - } - - utils.SendMsg(client, types.ReplyMessage{ - Type: types.MsgTypeText, - Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), - }) - } // end for - - utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd}) - - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求 OpenAI API 失败:%s", string(body)) + resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err)) + return } // 扣减算力 if chatModel.Power > 0 { - err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{ + err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{ Type: types.PowerConsume, Model: chatModel.Value, Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), }) if err != nil { - return err + resp.ERROR(c, "error with save power log, "+err.Error()) + return } } - return nil -} - -func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { - - session := h.DB.Session(&gorm.Session{}) - // if the chat model bind a KEY, use it directly - if chatModel.KeyId > 0 { - session = session.Where("id", chatModel.KeyId) - } else { // use the last unused key - session = session.Where("type", "chat"). - Where("enabled", true).Order("last_used_at ASC") - } - - res := session.First(apiKey) - if res.Error != nil { - return nil, errors.New("no available key, please import key") - } - apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) - // 更新 API KEY 的最后使用时间 - h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - - // 创建 HttpClient 请求对象 - var client *http.Client - requestBody, err := json.Marshal(req) - if err != nil { - return nil, err - } - request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) - if err != nil { - return nil, err - } - - request.Header.Set("Content-Type", "application/json") - if len(apiKey.ProxyURL) > 5 { // 使用代理 - proxy, _ := url.Parse(apiKey.ProxyURL) - client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxy), - }, - } - } else { - client = http.DefaultClient - } - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) - logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model) - return client.Do(request) + resp.SUCCESS(c, content) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 0d2ba6ea..9f5345dd 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, if err != nil { continue } - - if item.Progress < 100 { - // 从 leveldb 中获取图片预览数据 - var imageData string - err = h.leveldb.Get(item.TaskId, &imageData) - if err == nil { - job.ImgURL = "data:image/png;base64," + imageData - } - } jobs = append(jobs, job) } diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 8df60385..d284c66f 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl func (h *SunoHandler) Create(c *gin.Context) { var data struct { + ClientId string `json:"client_id"` Prompt string `json:"prompt"` Instrumental bool `json:"instrumental"` Lyrics string `json:"lyrics"` @@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) { // 创建任务 h.sunoService.PushTask(types.SunoTask{ + ClientId: data.ClientId, Id: job.Id, UserId: job.UserId, Type: job.Type, @@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) { return } - client := h.sunoService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } resp.SUCCESS(c) } @@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0) if err != nil { resp.ERROR(c, err.Error()) return diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index f9a911ab..02bf21cb 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u func (h *VideoHandler) LumaCreate(c *gin.Context) { var data struct { + ClientId string `json:"client_id"` Prompt string `json:"prompt"` FirstFrameImg string `json:"first_frame_img,omitempty"` EndFrameImg string `json:"end_frame_img,omitempty"` @@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { // 创建任务 h.videoService.PushTask(types.VideoTask{ - Id: job.Id, - UserId: userId, - Type: types.VideoLuma, - Prompt: data.Prompt, - Params: params, + ClientId: data.ClientId, + Id: job.Id, + UserId: userId, + Type: types.VideoLuma, + Prompt: data.Prompt, + Params: params, }) // update user's power @@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { resp.ERROR(c, err.Error()) return } - - client := h.videoService.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } resp.SUCCESS(c) } @@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) { return } // 只有失败或者超时的任务才能删除 - if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) { + if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) { resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!") return } diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go index 05933116..1835f8ab 100644 --- a/api/handler/ws_handler.go +++ b/api/handler/ws_handler.go @@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) { continue } - logger.Infof("Receive a message:%+v", message) + logger.Debugf("Receive a message:%+v", message) if message.Type == types.MsgTypePing { utils.SendChannelMsg(client, types.ChPing, "pong") continue diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index d0732413..12bef395 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -34,19 +34,21 @@ type Service struct { uploadManager *oss.UploaderManager taskQueue *store.RedisQueue notifyQueue *store.RedisQueue - Clients *types.LMap[uint, *types.WsClient] // UserId => Client userService *service.UserService + wsService *service.WebsocketService + clientIds map[uint]string } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), - Clients: types.NewLMap[uint, *types.WsClient](), + wsService: wsService, uploadManager: manager, userService: userService, + clientIds: map[uint]string{}, } } @@ -67,6 +69,7 @@ func (s *Service) Run() { continue } logger.Infof("handle a new DALL-E task: %+v", task) + s.clientIds[task.JobId] = task.ClientId _, err = s.Image(task, false) if err != nil { logger.Errorf("error with image task: %v", err) @@ -74,7 +77,7 @@ func (s *Service) Run() { "progress": service.FailTaskProgress, "err_msg": err.Error(), }) - s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) } } }() @@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { prompt := task.Prompt // translate prompt if utils.HasChinese(prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0) if err == nil { prompt = content logger.Debugf("重写后提示词:%s", prompt) @@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { return "", fmt.Errorf("err with update database: %v", err) } - s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) var content string if sync { imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) @@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } - client := s.Clients.Get(uint(message.UserId)) + + logger.Debugf("notify message: %+v", message) + client := s.wsService.Clients.Get(message.ClientId) if client == nil { continue } - err = client.Send([]byte(message.Message)) - if err != nil { - continue - } + utils.SendChannelMsg(client, types.ChDall, message.Message) } }() } @@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, if res.Error != nil { return "", err } - s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished}) return imgURL, nil } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 8086318a..72192dcf 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -58,7 +58,7 @@ func (s *Service) Run() { // translate prompt if utils.HasChinese(task.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0) if err == nil { task.Prompt = content } else { @@ -67,7 +67,7 @@ func (s *Service) Run() { } // translate negative prompt if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0) if err == nil { task.NegPrompt = content } else { @@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } + logger.Debugf("receive a new mj notify message: %+v", message) client := s.wsService.Clients.Get(message.ClientId) if client == nil { continue diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 9aa25c2a..9bfd1ecd 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -33,7 +33,6 @@ type Service struct { notifyQueue *store.RedisQueue db *gorm.DB uploadManager *oss.UploaderManager - leveldb *store.LevelDB wsService *service.WebsocketService } @@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), db: db, - leveldb: levelDB, wsService: wsService, uploadManager: manager, } @@ -62,7 +60,7 @@ func (s *Service) Run() { // translate prompt if utils.HasChinese(task.Params.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0) if err == nil { task.Params.Prompt = content } else { @@ -72,7 +70,7 @@ func (s *Service) Run() { // translate negative prompt if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0) if err == nil { task.Params.NegPrompt = content } else { @@ -126,9 +124,8 @@ type Txt2ImgResp struct { // TaskProgressResp 任务进度响应实体 type TaskProgressResp struct { - Progress float64 `json:"progress"` - EtaRelative float64 `json:"eta_relative"` - CurrentImage string `json:"current_image"` + Progress float64 `json:"progress"` + EtaRelative float64 `json:"eta_relative"` } // Txt2Img 文生图 API @@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error { // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) - // 从 leveldb 中删除预览图片数据 - _ = s.leveldb.Delete(task.Params.TaskId) return nil default: err, resp := s.checkTaskProgress(apiKey) @@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) // 发送更新状态信号 s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) - // 保存预览图片数据 - if resp.CurrentImage != "" { - _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) - } } time.Sleep(time.Second) } @@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } + logger.Debugf("notify message: %+v", message) client := s.wsService.Clients.Get(message.ClientId) if client == nil { continue diff --git a/api/service/suno/service.go b/api/service/suno/service.go index e3e502dd..9e293b3e 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -34,17 +34,19 @@ type Service struct { uploadManager *oss.UploaderManager taskQueue *store.RedisQueue notifyQueue *store.RedisQueue - Clients *types.LMap[uint, *types.WsClient] // UserId => Client + wsService *service.WebsocketService + clientIds map[string]string } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli), - Clients: types.NewLMap[uint, *types.WsClient](), uploadManager: manager, + wsService: wsService, + clientIds: map[string]string{}, } } @@ -96,7 +98,7 @@ func (s *Service) Run() { "err_msg": err.Error(), "progress": service.FailTaskProgress, }) - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) continue } @@ -105,6 +107,7 @@ func (s *Service) Run() { "task_id": r.Data, "channel": r.Channel, }) + s.clientIds[r.Data] = task.ClientId } }() } @@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } - client := s.Clients.Get(uint(message.UserId)) + logger.Debugf("notify message: %+v", message) + logger.Debugf("client id: %+v", s.wsService.Clients) + client := s.wsService.Clients.Get(message.ClientId) + logger.Debugf("%+v", client) if client == nil { continue } - err = client.Send([]byte(message.Message)) - if err != nil { - continue - } + utils.SendChannelMsg(client, types.ChSuno, message.Message) } }() } @@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() { v.AudioURL = audioURL v.Progress = 100 s.db.Updates(&v) - s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) } time.Sleep(time.Second * 10) @@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() { } } tx.Commit() - + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished}) } else if task.Data.FailReason != "" { job.Progress = service.FailTaskProgress job.ErrMsg = task.Data.FailReason s.db.Updates(&job) - s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) } } diff --git a/api/service/video/luma.go b/api/service/video/luma.go index 2b1f250d..144269f7 100644 --- a/api/service/video/luma.go +++ b/api/service/video/luma.go @@ -34,17 +34,19 @@ type Service struct { uploadManager *oss.UploaderManager taskQueue *store.RedisQueue notifyQueue *store.RedisQueue - Clients *types.LMap[uint, *types.WsClient] // UserId => Client + wsService *service.WebsocketService + clientIds map[uint]string } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), - Clients: types.NewLMap[uint, *types.WsClient](), + wsService: wsService, uploadManager: manager, + clientIds: map[uint]string{}, } } @@ -85,7 +87,7 @@ func (s *Service) Run() { // translate prompt if utils.HasChinese(task.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini") + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0) if err == nil { task.Prompt = content } else { @@ -93,6 +95,10 @@ func (s *Service) Run() { } } + if task.ClientId != "" { + s.clientIds[task.Id] = task.ClientId + } + var r LumaRespVo r, err = s.LumaCreate(task) if err != nil { @@ -105,7 +111,7 @@ func (s *Service) Run() { if err != nil { logger.Errorf("update task with error: %v", err) } - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) continue } @@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() { if err != nil { continue } - client := s.Clients.Get(uint(message.UserId)) + logger.Debugf("Receive notify message: %+v", message) + client := s.wsService.Clients.Get(message.ClientId) if client == nil { continue } - err = client.Send([]byte(message.Message)) - if err != nil { - continue - } + utils.SendChannelMsg(client, types.ChLuma, message.Message) } }() } @@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() { v.VideoURL = videoURL v.Progress = 100 s.db.Updates(&v) - s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) } time.Sleep(time.Second * 10) diff --git a/api/utils/openai.go b/api/utils/openai.go index c9d7363a..3c1e4f15 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -45,18 +45,25 @@ type apiRes struct { } `json:"choices"` } -func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) { - var apiKey model.ApiKey - res := db.Where("type", "chat").Where("enabled", true).First(&apiKey) - if res.Error != nil { - return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) - } - +func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) { messages := make([]interface{}, 1) messages[0] = types.Message{ Role: "user", Content: prompt, } + return SendOpenAIMessage(db, messages, modelName, keyId) +} + +func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) { + var apiKey model.ApiKey + session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true) + if keyId > 0 { + session = session.Where("id", keyId) + } + err := session.First(&apiKey).Error + if err != nil { + return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err) + } var response apiRes client := req.C() diff --git a/database/update-v4.1.5.sql b/database/update-v4.1.5.sql new file mode 100644 index 00000000..4ab40b95 --- /dev/null +++ b/database/update-v4.1.5.sql @@ -0,0 +1 @@ +ALTER TABLE `chatgpt_power_logs` CHANGE `remark` `remark` VARCHAR(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '备注'; \ No newline at end of file diff --git a/web/src/views/Dalle.vue b/web/src/views/Dalle.vue index bb88bae2..fc7a54dd 100644 --- a/web/src/views/Dalle.vue +++ b/web/src/views/Dalle.vue @@ -208,7 +208,7 @@ import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue"; import {httpGet, httpPost} from "@/utils/http"; import {ElMessage, ElMessageBox} from "element-plus"; import Clipboard from "clipboard"; -import {checkSession, getSystemInfo} from "@/store/cache"; +import {checkSession, getClientId, getSystemInfo} from "@/store/cache"; import {useSharedStore} from "@/store/sharedata"; import TaskList from "@/components/TaskList.vue"; import BackTop from "@/components/BackTop.vue"; @@ -240,6 +240,7 @@ const styles = [ {name: "自然", value: "natural"} ] const params = ref({ + client_id: getClientId(), quality: "standard", size: "1024x1024", style: "vivid", @@ -268,14 +269,24 @@ onMounted(() => { }).catch(e => { ElMessage.error("获取系统配置失败:" + e.message) }) + + store.addMessageHandler("dall",(data) => { + // 丢弃无关消息 + if (data.channel !== "dall" || data.clientId !== getClientId()) { + return + } + + if (data.body === "FINISH" || data.body === "FAIL") { + page.value = 0 + isOver.value = false + fetchFinishJobs() + } + nextTick(() => fetchRunningJobs()) + }) }) onUnmounted(() => { clipboard.value.destroy() - if (socket.value !== null) { - socket.value.close() - socket.value = null - } }) const initData = () => { @@ -287,51 +298,10 @@ const initData = () => { page.value = 0 fetchRunningJobs() fetchFinishJobs() - connect() }).catch(() => { }); } -const socket = ref(null) -const heartbeatHandle = ref(null) -const connect = () => { - let host = process.env.VUE_APP_WS_HOST - if (host === '') { - if (location.protocol === 'https:') { - host = 'wss://' + location.host; - } else { - host = 'ws://' + location.host; - } - } - - const _socket = new WebSocket(host + `/api/dall/client?user_id=${userId.value}`); - _socket.addEventListener('open', () => { - socket.value = _socket; - }); - - _socket.addEventListener('message', event => { - if (event.data instanceof Blob) { - const reader = new FileReader(); - reader.readAsText(event.data, "UTF-8") - reader.onload = () => { - const message = String(reader.result) - if (message === "FINISH" || message === "FAIL") { - page.value = 0 - isOver.value = false - fetchFinishJobs(page.value) - } - nextTick(() => fetchRunningJobs()) - } - } - }); - - _socket.addEventListener('close', () => { - if (socket.value !== null) { - connect() - } - }) -} - const fetchRunningJobs = () => { if (!isLogin.value) { return @@ -391,6 +361,7 @@ const generate = () => { httpPost("/api/dall/image", params.value).then(() => { ElMessage.success("任务执行成功!") power.value -= dallPower.value + fetchRunningJobs() }).catch(e => { ElMessage.error("任务执行失败:" + e.message) }) diff --git a/web/src/views/Luma.vue b/web/src/views/Luma.vue index afd51a1c..089f41f8 100644 --- a/web/src/views/Luma.vue +++ b/web/src/views/Luma.vue @@ -55,25 +55,6 @@

你的作品

- - - - - - - - - - - - - - - - - - -
@@ -153,13 +134,14 @@ import {onMounted, reactive, ref} from "vue"; import {CircleCloseFilled} from "@element-plus/icons-vue"; import {httpDownload, httpPost, httpGet} from "@/utils/http"; -import {checkSession} from "@/store/cache"; +import {checkSession, getClientId} from "@/store/cache"; import {showMessageError, showMessageOK} from "@/utils/dialog"; import { replaceImg } from "@/utils/libs" import {ElMessage, ElMessageBox} from "element-plus"; import BlackSwitch from "@/components/ui/BlackSwitch.vue"; import Generating from "@/components/ui/Generating.vue"; import BlackDialog from "@/components/ui/BlackDialog.vue"; +import {useSharedStore} from "@/store/sharedata"; const showDialog = ref(false) const currentVideoUrl = ref('') @@ -167,6 +149,7 @@ const row = ref(1) const images = ref([]) const formData = reactive({ + client_id: getClientId(), prompt: '', expand_prompt: false, loop: false, @@ -174,49 +157,22 @@ const formData = reactive({ end_frame_img: '' }) -const socket = ref(null) -const userId = ref(0) -const connect = () => { - let host = process.env.VUE_APP_WS_HOST - if (host === '') { - if (location.protocol === 'https:') { - host = 'wss://' + location.host; - } else { - host = 'ws://' + location.host; - } - } - - const _socket = new WebSocket(host + `/api/video/client?user_id=${userId.value}`); - _socket.addEventListener('open', () => { - socket.value = _socket; - }); - - _socket.addEventListener('message', event => { - if (event.data instanceof Blob) { - const reader = new FileReader(); - reader.readAsText(event.data, "UTF-8") - reader.onload = () => { - const message = String(reader.result) - if (message === "FINISH" || message === "FAIL") { - fetchData() - } - } - } - }); - - _socket.addEventListener('close', () => { - if (socket.value !== null) { - connect() - } - }); -} - +const store = useSharedStore() onMounted(()=>{ - checkSession().then(user => { - userId.value = user.id - connect() + checkSession().then(() => { + fetchData(1) + }) + + store.addMessageHandler("luma",(data) => { + // 丢弃无关消息 + if (data.channel !== "luma" || data.clientId !== getClientId()) { + return + } + + if (data.body === "FINISH" || data.body === "FAIL") { + fetchData(1) + } }) - fetchData(1) }) const download = (item) => { diff --git a/web/src/views/MarkMap.vue b/web/src/views/MarkMap.vue index d2bef90c..927a2f2b 100644 --- a/web/src/views/MarkMap.vue +++ b/web/src/views/MarkMap.vue @@ -45,7 +45,7 @@
- 智能生成思维导图 + 生成思维导图
@@ -79,10 +79,7 @@
-
-
-
-
+
@@ -94,11 +91,11 @@