websocket api refactor is ready

This commit is contained in:
RockYang 2024-09-29 19:28:47 +08:00
parent 8093a3eeb2
commit 1a1734abf0
19 changed files with 210 additions and 464 deletions

View File

@ -1,4 +1,8 @@
# 更新日志 # 更新日志
## v4.1.5
* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
* Bug修复兼容手机端原生微信支付和支付宝支付渠道
* Bug修复修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题
## v4.1.4 ## v4.1.4
* 功能优化:用户文件列表组件增加分页功能支持 * 功能优化:用户文件列表组件增加分页功能支持
* Bug修复修复用户注册失败Bug注册操作只弹出一次行为验证码 * Bug修复修复用户注册失败Bug注册操作只弹出一次行为验证码

View File

@ -72,18 +72,20 @@ type SdTaskParams struct {
// DallTask DALL-E task // DallTask DALL-E task
type DallTask struct { type DallTask struct {
JobId uint `json:"job_id"` ClientId string `json:"client_id"`
UserId uint `json:"user_id"` JobId uint `json:"job_id"`
Prompt string `json:"prompt"` UserId uint `json:"user_id"`
N int `json:"n"` Prompt string `json:"prompt"`
Quality string `json:"quality"` N int `json:"n"`
Size string `json:"size"` Quality string `json:"quality"`
Style string `json:"style"` Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"` Power int `json:"power"`
} }
type SunoTask struct { type SunoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"` Id uint `json:"id"`
Channel string `json:"channel"` Channel string `json:"channel"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
@ -107,13 +109,14 @@ const (
) )
type VideoTask struct { type VideoTask struct {
Id uint `json:"id"` ClientId string `json:"client_id"`
Channel string `json:"channel"` Id uint `json:"id"`
UserId int `json:"user_id"` Channel string `json:"channel"`
Type string `json:"type"` UserId int `json:"user_id"`
TaskId string `json:"task_id"` Type string `json:"type"`
Prompt string `json:"prompt"` // 提示词 TaskId string `json:"task_id"`
Params VideoParams `json:"params"` Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
} }
type VideoParams struct { type VideoParams struct {

View File

@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
} }
h.dallService.PushTask(types.DallTask{ h.dallService.PushTask(types.DallTask{
JobId: job.Id, ClientId: data.ClientId,
UserId: uint(userId), JobId: job.Id,
Prompt: data.Prompt, UserId: uint(userId),
Quality: data.Quality, Prompt: data.Prompt,
Size: data.Size, Quality: data.Quality,
Style: data.Style, Size: data.Size,
Power: job.Power, Style: data.Style,
Power: job.Power,
}) })
client := h.dallService.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@ -8,23 +8,15 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt" "fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
) )
// MarkMapHandler 生成思维导图 // MarkMapHandler 生成思维导图
@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
// Generate 生成思维导图 // Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) { func (h *MarkMapHandler) Generate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
ModelId int `json:"model_id"`
}
} if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { userId := h.GetLoginUserId(c)
var user model.User var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId) err := h.DB.Where("id", userId).First(&user, userId).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query user info: %v", res.Error) resp.ERROR(c, "error with query user info")
return
} }
var chatModel model.ChatModel var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel) err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if res.Error != nil { if err != nil {
return fmt.Errorf("error with query chat model: %v", res.Error) resp.ERROR(c, "error with query chat model")
return
} }
if user.Power < chatModel.Power { if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power) resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power))
return
} }
messages := make([]interface{}, 0) messages := make([]interface{}, 0)
@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
### 支付宝 ### 支付宝
### 微信 ### 微信
另外除此之外不要任何解释性语句 请直接生成结果不要任何解释性语句
`}) `})
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", prompt)}) messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", data.Prompt)})
var req = types.ApiRequest{ content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
if err != nil { if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err) resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
} return
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return fmt.Errorf("error with decode data: %v", line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
utils.SendMsg(client, types.ReplyMessage{
Type: types.MsgTypeText,
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
} }
// 扣减算力 // 扣减算力
if chatModel.Power > 0 { if chatModel.Power > 0 {
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{ err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,
Model: chatModel.Value, Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value), Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
}) })
if err != nil { if err != nil {
return err resp.ERROR(c, "error with save power log, "+err.Error())
return
} }
} }
return nil resp.SUCCESS(c, content)
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
if chatModel.KeyId > 0 {
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
return client.Do(request)
} }

View File

@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
if err != nil { if err != nil {
continue continue
} }
if item.Progress < 100 {
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job) jobs = append(jobs, job)
} }

View File

@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
func (h *SunoHandler) Create(c *gin.Context) { func (h *SunoHandler) Create(c *gin.Context) {
var data struct { var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"` Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"` Lyrics string `json:"lyrics"`
@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
// 创建任务 // 创建任务
h.sunoService.PushTask(types.SunoTask{ h.sunoService.PushTask(types.SunoTask{
ClientId: data.ClientId,
Id: job.Id, Id: job.Id,
UserId: job.UserId, UserId: job.UserId,
Type: job.Type, Type: job.Type,
@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
return return
} }
client := h.sunoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return

View File

@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
func (h *VideoHandler) LumaCreate(c *gin.Context) { func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct { var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"` FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"` EndFrameImg string `json:"end_frame_img,omitempty"`
@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
// 创建任务 // 创建任务
h.videoService.PushTask(types.VideoTask{ h.videoService.PushTask(types.VideoTask{
Id: job.Id, ClientId: data.ClientId,
UserId: userId, Id: job.Id,
Type: types.VideoLuma, UserId: userId,
Prompt: data.Prompt, Type: types.VideoLuma,
Params: params, Prompt: data.Prompt,
Params: params,
}) })
// update user's power // update user's power
@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
client := h.videoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) {
return return
} }
// 只有失败或者超时的任务才能删除 // 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) { if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!") resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return return
} }

View File

@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
continue continue
} }
logger.Infof("Receive a message:%+v", message) logger.Debugf("Receive a message:%+v", message)
if message.Type == types.MsgTypePing { if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong") utils.SendChannelMsg(client, types.ChPing, "pong")
continue continue

View File

@ -34,19 +34,21 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
userService *service.UserService userService *service.UserService
wsService *service.WebsocketService
clientIds map[uint]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
userService: userService, userService: userService,
clientIds: map[uint]string{},
} }
} }
@ -67,6 +69,7 @@ func (s *Service) Run() {
continue continue
} }
logger.Infof("handle a new DALL-E task: %+v", task) logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.JobId] = task.ClientId
_, err = s.Image(task, false) _, err = s.Image(task, false)
if err != nil { if err != nil {
logger.Errorf("error with image task: %v", err) logger.Errorf("error with image task: %v", err)
@ -74,7 +77,7 @@ func (s *Service) Run() {
"progress": service.FailTaskProgress, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
} }
} }
}() }()
@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
prompt := task.Prompt prompt := task.Prompt
// translate prompt // translate prompt
if utils.HasChinese(prompt) { if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
prompt = content prompt = content
logger.Debugf("重写后提示词:%s", prompt) logger.Debugf("重写后提示词:%s", prompt)
@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", fmt.Errorf("err with update database: %v", err) return "", fmt.Errorf("err with update database: %v", err)
} }
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChDall, message.Message)
if err != nil {
continue
}
} }
}() }()
} }
@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil { if res.Error != nil {
return "", err return "", err
} }
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil return imgURL, nil
} }

View File

@ -58,7 +58,7 @@ func (s *Service) Run() {
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.Prompt = content task.Prompt = content
} else { } else {
@ -67,7 +67,7 @@ func (s *Service) Run() {
} }
// translate negative prompt // translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.NegPrompt = content task.NegPrompt = content
} else { } else {
@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
logger.Debugf("receive a new mj notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId) client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue

View File

@ -33,7 +33,6 @@ type Service struct {
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
leveldb *store.LevelDB
wsService *service.WebsocketService wsService *service.WebsocketService
} }
@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB,
wsService: wsService, wsService: wsService,
uploadManager: manager, uploadManager: manager,
} }
@ -62,7 +60,7 @@ func (s *Service) Run() {
// translate prompt // translate prompt
if utils.HasChinese(task.Params.Prompt) { if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.Params.Prompt = content task.Params.Prompt = content
} else { } else {
@ -72,7 +70,7 @@ func (s *Service) Run() {
// translate negative prompt // translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.Params.NegPrompt = content task.Params.NegPrompt = content
} else { } else {
@ -126,9 +124,8 @@ type Txt2ImgResp struct {
// TaskProgressResp 任务进度响应实体 // TaskProgressResp 任务进度响应实体
type TaskProgressResp struct { type TaskProgressResp struct {
Progress float64 `json:"progress"` Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"` EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
} }
// Txt2Img 文生图 API // Txt2Img 文生图 API
@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
default: default:
err, resp := s.checkTaskProgress(apiKey) err, resp := s.checkTaskProgress(apiKey)
@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId) client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue

View File

@ -34,17 +34,19 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client wsService *service.WebsocketService
clientIds map[string]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager, uploadManager: manager,
wsService: wsService,
clientIds: map[string]string{},
} }
} }
@ -96,7 +98,7 @@ func (s *Service) Run() {
"err_msg": err.Error(), "err_msg": err.Error(),
"progress": service.FailTaskProgress, "progress": service.FailTaskProgress,
}) })
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue continue
} }
@ -105,6 +107,7 @@ func (s *Service) Run() {
"task_id": r.Data, "task_id": r.Data,
"channel": r.Channel, "channel": r.Channel,
}) })
s.clientIds[r.Data] = task.ClientId
} }
}() }()
} }
@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId)) logger.Debugf("notify message: %+v", message)
logger.Debugf("client id: %+v", s.wsService.Clients)
client := s.wsService.Clients.Get(message.ClientId)
logger.Debugf("%+v", client)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChSuno, message.Message)
if err != nil {
continue
}
} }
}() }()
} }
@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
v.AudioURL = audioURL v.AudioURL = audioURL
v.Progress = 100 v.Progress = 100
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
} }
} }
tx.Commit() tx.Commit()
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
} else if task.Data.FailReason != "" { } else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason job.ErrMsg = task.Data.FailReason
s.db.Updates(&job) s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
} }
} }

View File

@ -34,17 +34,19 @@ type Service struct {
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client wsService *service.WebsocketService
clientIds map[uint]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
clientIds: map[uint]string{},
} }
} }
@ -85,7 +87,7 @@ func (s *Service) Run() {
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil { if err == nil {
task.Prompt = content task.Prompt = content
} else { } else {
@ -93,6 +95,10 @@ func (s *Service) Run() {
} }
} }
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
var r LumaRespVo var r LumaRespVo
r, err = s.LumaCreate(task) r, err = s.LumaCreate(task)
if err != nil { if err != nil {
@ -105,7 +111,7 @@ func (s *Service) Run() {
if err != nil { if err != nil {
logger.Errorf("update task with error: %v", err) logger.Errorf("update task with error: %v", err)
} }
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue continue
} }
@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId)) logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChLuma, message.Message)
if err != nil {
continue
}
} }
}() }()
} }
@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() {
v.VideoURL = videoURL v.VideoURL = videoURL
v.Progress = 100 v.Progress = 100
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)

View File

@ -45,18 +45,25 @@ type apiRes struct {
} `json:"choices"` } `json:"choices"`
} }
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) { func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
var apiKey model.ApiKey
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}
messages := make([]interface{}, 1) messages := make([]interface{}, 1)
messages[0] = types.Message{ messages[0] = types.Message{
Role: "user", Role: "user",
Content: prompt, Content: prompt,
} }
return SendOpenAIMessage(db, messages, modelName, keyId)
}
func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
var apiKey model.ApiKey
session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
if keyId > 0 {
session = session.Where("id", keyId)
}
err := session.First(&apiKey).Error
if err != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", err)
}
var response apiRes var response apiRes
client := req.C() client := req.C()

View File

@ -0,0 +1 @@
ALTER TABLE `chatgpt_power_logs` CHANGE `remark` `remark` VARCHAR(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '备注';

View File

@ -208,7 +208,7 @@ import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession, getSystemInfo} from "@/store/cache"; import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
import {useSharedStore} from "@/store/sharedata"; import {useSharedStore} from "@/store/sharedata";
import TaskList from "@/components/TaskList.vue"; import TaskList from "@/components/TaskList.vue";
import BackTop from "@/components/BackTop.vue"; import BackTop from "@/components/BackTop.vue";
@ -240,6 +240,7 @@ const styles = [
{name: "自然", value: "natural"} {name: "自然", value: "natural"}
] ]
const params = ref({ const params = ref({
client_id: getClientId(),
quality: "standard", quality: "standard",
size: "1024x1024", size: "1024x1024",
style: "vivid", style: "vivid",
@ -268,14 +269,24 @@ onMounted(() => {
}).catch(e => { }).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message) ElMessage.error("获取系统配置失败:" + e.message)
}) })
store.addMessageHandler("dall",(data) => {
//
if (data.channel !== "dall" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
})
}) })
onUnmounted(() => { onUnmounted(() => {
clipboard.value.destroy() clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
}) })
const initData = () => { const initData = () => {
@ -287,51 +298,10 @@ const initData = () => {
page.value = 0 page.value = 0
fetchRunningJobs() fetchRunningJobs()
fetchFinishJobs() fetchFinishJobs()
connect()
}).catch(() => { }).catch(() => {
}); });
} }
const socket = ref(null)
const heartbeatHandle = ref(null)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/dall/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs(page.value)
}
nextTick(() => fetchRunningJobs())
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
})
}
const fetchRunningJobs = () => { const fetchRunningJobs = () => {
if (!isLogin.value) { if (!isLogin.value) {
return return
@ -391,6 +361,7 @@ const generate = () => {
httpPost("/api/dall/image", params.value).then(() => { httpPost("/api/dall/image", params.value).then(() => {
ElMessage.success("任务执行成功!") ElMessage.success("任务执行成功!")
power.value -= dallPower.value power.value -= dallPower.value
fetchRunningJobs()
}).catch(e => { }).catch(e => {
ElMessage.error("任务执行失败:" + e.message) ElMessage.error("任务执行失败:" + e.message)
}) })

View File

@ -55,25 +55,6 @@
<el-container class="video-container" v-loading="loading" element-loading-background="rgba(100,100,100,0.3)"> <el-container class="video-container" v-loading="loading" element-loading-background="rgba(100,100,100,0.3)">
<h2 class="h-title">你的作品</h2> <h2 class="h-title">你的作品</h2>
<!-- <el-row :gutter="20" class="videos" v-if="!noData">-->
<!-- <el-col :span="8" class="item" :key="item.id" v-for="item in videos">-->
<!-- <div class="video-box" @mouseover="item.playing = true" @mouseout="item.playing = false">-->
<!-- <img :src="item.cover" :alt="item.name" v-show="!item.playing"/>-->
<!-- <video :src="item.url" preload="auto" :autoplay="true" loop="loop" muted="muted" v-show="item.playing">-->
<!-- 您的浏览器不支持视频播放-->
<!-- </video>-->
<!-- </div>-->
<!-- <div class="video-name">{{item.name}}</div>-->
<!-- <div class="opts">-->
<!-- <button class="btn" @click="download(item)" :disabled="item.downloading">-->
<!-- <i class="iconfont icon-download" v-if="!item.downloading"></i>-->
<!-- <el-image src="/images/loading.gif" fit="cover" v-else />-->
<!-- <span>下载</span>-->
<!-- </button>-->
<!-- </div>-->
<!-- </el-col>-->
<!-- </el-row>-->
<div class="list-box" v-if="!noData"> <div class="list-box" v-if="!noData">
<div v-for="item in list" :key="item.id"> <div v-for="item in list" :key="item.id">
<div class="item"> <div class="item">
@ -153,13 +134,14 @@
import {onMounted, reactive, ref} from "vue"; import {onMounted, reactive, ref} from "vue";
import {CircleCloseFilled} from "@element-plus/icons-vue"; import {CircleCloseFilled} from "@element-plus/icons-vue";
import {httpDownload, httpPost, httpGet} from "@/utils/http"; import {httpDownload, httpPost, httpGet} from "@/utils/http";
import {checkSession} from "@/store/cache"; import {checkSession, getClientId} from "@/store/cache";
import {showMessageError, showMessageOK} from "@/utils/dialog"; import {showMessageError, showMessageOK} from "@/utils/dialog";
import { replaceImg } from "@/utils/libs" import { replaceImg } from "@/utils/libs"
import {ElMessage, ElMessageBox} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import BlackSwitch from "@/components/ui/BlackSwitch.vue"; import BlackSwitch from "@/components/ui/BlackSwitch.vue";
import Generating from "@/components/ui/Generating.vue"; import Generating from "@/components/ui/Generating.vue";
import BlackDialog from "@/components/ui/BlackDialog.vue"; import BlackDialog from "@/components/ui/BlackDialog.vue";
import {useSharedStore} from "@/store/sharedata";
const showDialog = ref(false) const showDialog = ref(false)
const currentVideoUrl = ref('') const currentVideoUrl = ref('')
@ -167,6 +149,7 @@ const row = ref(1)
const images = ref([]) const images = ref([])
const formData = reactive({ const formData = reactive({
client_id: getClientId(),
prompt: '', prompt: '',
expand_prompt: false, expand_prompt: false,
loop: false, loop: false,
@ -174,49 +157,22 @@ const formData = reactive({
end_frame_img: '' end_frame_img: ''
}) })
const socket = ref(null) const store = useSharedStore()
const userId = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/video/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
fetchData()
}
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
});
}
onMounted(()=>{ onMounted(()=>{
checkSession().then(user => { checkSession().then(() => {
userId.value = user.id fetchData(1)
connect() })
store.addMessageHandler("luma",(data) => {
//
if (data.channel !== "luma" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
fetchData(1)
}
}) })
fetchData(1)
}) })
const download = (item) => { const download = (item) => {

View File

@ -45,7 +45,7 @@
<div class="param-line"> <div class="param-line">
<el-button color="#47fff1" :dark="false" round @click="generateAI" :loading="loading"> <el-button color="#47fff1" :dark="false" round @click="generateAI" :loading="loading">
智能生成思维导图 生成思维导图
</el-button> </el-button>
</div> </div>
@ -79,10 +79,7 @@
</el-button> </el-button>
</div> </div>
<div class="markdown" v-if="loading"> <div class="body" id="markmap">
<div :style="{ height: rightBoxHeight + 'px', overflow:'auto',width:'80%' }" v-html="html"></div>
</div>
<div class="body" id="markmap" v-show="!loading">
<svg ref="svgRef" :style="{ height: rightBoxHeight + 'px' }"/> <svg ref="svgRef" :style="{ height: rightBoxHeight + 'px' }"/>
<div id="toolbar"></div> <div id="toolbar"></div>
</div> </div>
@ -94,11 +91,11 @@
</template> </template>
<script setup> <script setup>
import {nextTick, onUnmounted, ref} from 'vue'; import {nextTick, ref} from 'vue';
import {Markmap} from 'markmap-view'; import {Markmap} from 'markmap-view';
import {Transformer} from 'markmap-lib'; import {Transformer} from 'markmap-lib';
import {checkSession, getSystemInfo} from "@/store/cache"; import {checkSession, getSystemInfo} from "@/store/cache";
import {httpGet} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage} from "element-plus"; import {ElMessage} from "element-plus";
import {Download} from "@element-plus/icons-vue"; import {Download} from "@element-plus/icons-vue";
import {Toolbar} from 'markmap-toolbar'; import {Toolbar} from 'markmap-toolbar';
@ -106,11 +103,9 @@ import {useSharedStore} from "@/store/sharedata";
const leftBoxHeight = ref(window.innerHeight - 105) const leftBoxHeight = ref(window.innerHeight - 105)
const rightBoxHeight = ref(window.innerHeight - 115) const rightBoxHeight = ref(window.innerHeight - 115)
const title = ref("")
const prompt = ref("") const prompt = ref("")
const text = ref("") const text = ref("")
const md = require('markdown-it')({breaks: true});
const content = ref(text.value) const content = ref(text.value)
const html = ref("") const html = ref("")
@ -118,13 +113,12 @@ const isLogin = ref(false)
const loginUser = ref({power: 0}) const loginUser = ref({power: 0})
const transformer = new Transformer(); const transformer = new Transformer();
const store = useSharedStore(); const store = useSharedStore();
const loading = ref(false)
const svgRef = ref(null) const svgRef = ref(null)
const markMap = ref(null) const markMap = ref(null)
const models = ref([]) const models = ref([])
const modelID = ref(0) const modelID = ref(0)
const loading = ref(false)
getSystemInfo().then(res => { getSystemInfo().then(res => {
text.value = res.data['mark_map_text'] text.value = res.data['mark_map_text']
@ -147,9 +141,7 @@ getSystemInfo().then(res => {
const initData = () => { const initData = () => {
httpGet("/api/model/list").then(res => { httpGet("/api/model/list").then(res => {
for (let v of res.data) { for (let v of res.data) {
if (v.value.indexOf("gpt-4-gizmo") === -1) { models.value.push(v)
models.value.push(v)
}
} }
modelID.value = models.value[0].id modelID.value = models.value[0].id
}).catch(e => { }).catch(e => {
@ -159,7 +151,6 @@ const initData = () => {
checkSession().then(user => { checkSession().then(user => {
loginUser.value = user loginUser.value = user
isLogin.value = true isLogin.value = true
connect(user.id)
}).catch(() => { }).catch(() => {
}); });
} }
@ -191,74 +182,11 @@ const processContent = (text) => {
return arr.join("\n") return arr.join("\n")
} }
onUnmounted(() => {
if (socket.value !== null) {
socket.value.close()
}
socket.value = null
})
window.onresize = () => { window.onresize = () => {
leftBoxHeight.value = window.innerHeight - 145 leftBoxHeight.value = window.innerHeight - 145
rightBoxHeight.value = window.innerHeight - 85 rightBoxHeight.value = window.innerHeight - 85
} }
const socket = ref(null)
const connect = (userId) => {
if (socket.value !== null) {
socket.value.close()
}
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/markMap/client?user_id=${userId}&model_id=${modelID.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
const model = getModelById(modelID.value)
reader.onload = () => {
const data = JSON.parse(String(reader.result))
switch (data.type) {
case "content":
text.value += data.content
html.value = md.render(processContent(text.value))
break
case "end":
loading.value = false
content.value = processContent(text.value)
loginUser.value.power -= model.power
nextTick(() => update())
break
case "error":
loading.value = false
ElMessage.error(data.content)
break
}
}
}
})
_socket.addEventListener('close', () => {
loading.value = false
checkSession().then(() => {
connect(userId)
}).catch(() => {
})
});
}
const generate = () => { const generate = () => {
text.value = content.value text.value = content.value
update() update()
@ -276,19 +204,26 @@ const generateAI = () => {
return return
} }
loading.value = true loading.value = true
socket.value.send(JSON.stringify({type: "message", content: prompt.value})) httpPost("/api/markMap/gen", {
} prompt:prompt.value,
model_id: modelID.value
const changeModel = () => { }).then(res => {
if (socket.value !== null) { text.value = res.data
socket.value.send(JSON.stringify({type: "model_id", content: modelID.value})) content.value = processContent(text.value)
} const model = getModelById(modelID.value)
loginUser.value.power -= model.power
nextTick(() => update())
loading.value = false
}).catch(e => {
ElMessage.error("生成思维导图失败:" + e.message)
loading.value = false
})
} }
const getModelById = (modelId) => { const getModelById = (modelId) => {
for (let e of models.value) { for (let m of models.value) {
if (e.id === modelId) { if (m.id === modelId) {
return e return m
} }
} }
} }

View File

@ -300,13 +300,14 @@ import MusicPlayer from "@/components/MusicPlayer.vue";
import {compact} from "lodash"; import {compact} from "lodash";
import {httpDownload, httpGet, httpPost} from "@/utils/http"; import {httpDownload, httpGet, httpPost} from "@/utils/http";
import {showMessageError, showMessageOK} from "@/utils/dialog"; import {showMessageError, showMessageOK} from "@/utils/dialog";
import {checkSession} from "@/store/cache"; import {checkSession, getClientId} from "@/store/cache";
import {ElMessage, ElMessageBox} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import {formatTime, replaceImg} from "@/utils/libs"; import {formatTime, replaceImg} from "@/utils/libs";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import BlackDialog from "@/components/ui/BlackDialog.vue"; import BlackDialog from "@/components/ui/BlackDialog.vue";
import Compressor from "compressorjs"; import Compressor from "compressorjs";
import Generating from "@/components/ui/Generating.vue"; import Generating from "@/components/ui/Generating.vue";
import {useSharedStore} from "@/store/sharedata";
const winHeight = ref(window.innerHeight - 50) const winHeight = ref(window.innerHeight - 50)
const custom = ref(false) const custom = ref(false)
@ -333,6 +334,7 @@ const tags = ref([
{label: "嘻哈", value: "hip hop"}, {label: "嘻哈", value: "hip hop"},
]) ])
const data = ref({ const data = ref({
client_id: getClientId(),
model: "chirp-v3-0", model: "chirp-v3-0",
tags: "", tags: "",
lyrics: "", lyrics: "",
@ -354,45 +356,7 @@ const refSong = ref(null)
const showDialog = ref(false) const showDialog = ref(false)
const editData = ref({title:"",cover:"",id:0}) const editData = ref({title:"",cover:"",id:0})
const promptPlaceholder = ref('请在这里输入你自己写的歌词...') const promptPlaceholder = ref('请在这里输入你自己写的歌词...')
const store = useSharedStore()
const socket = ref(null)
const userId = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/suno/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
console.log(message)
if (message === "FINISH" || message === "FAIL") {
fetchData()
}
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
});
}
const clipboard = ref(null) const clipboard = ref(null)
onMounted(() => { onMounted(() => {
clipboard.value = new Clipboard('.copy-link'); clipboard.value = new Clipboard('.copy-link');
@ -405,10 +369,19 @@ onMounted(() => {
}) })
checkSession().then(user => { checkSession().then(user => {
userId.value = user.id fetchData(1)
connect() })
store.addMessageHandler("suno",(data) => {
//
if (data.channel !== "suno" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
fetchData(1)
}
}) })
fetchData(1)
}) })
onUnmounted(() => { onUnmounted(() => {