mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
websocket api refactor is ready
This commit is contained in:
parent
00a8bc6784
commit
e28a12a1ee
@ -1,4 +1,8 @@
|
||||
# 更新日志
|
||||
## v4.1.5
|
||||
* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
|
||||
* Bug修复:兼容手机端原生微信支付和支付宝支付渠道
|
||||
* Bug修复:修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题
|
||||
## v4.1.4
|
||||
* 功能优化:用户文件列表组件增加分页功能支持
|
||||
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
|
||||
|
@ -72,18 +72,20 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
ClientId string `json:"client_id"`
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
|
||||
Power int `json:"power"`
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
@ -107,13 +109,14 @@ const (
|
||||
)
|
||||
|
||||
type VideoTask struct {
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
}
|
||||
|
||||
type VideoParams struct {
|
||||
|
@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.dallService.PushTask(types.DallTask{
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
ClientId: data.ClientId,
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
})
|
||||
|
||||
client := h.dallService.Clients.Get(job.UserId)
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
@ -8,23 +8,15 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MarkMapHandler 生成思维导图
|
||||
@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
ModelId int `json:"model_id"`
|
||||
}
|
||||
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, userId)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query user info: %v", res.Error)
|
||||
err := h.DB.Where("id", userId).First(&user, userId).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query user info")
|
||||
return
|
||||
}
|
||||
var chatModel model.ChatModel
|
||||
res = h.DB.Where("id", modelId).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query chat model: %v", res.Error)
|
||||
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query chat model")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < chatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
||||
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power))
|
||||
return
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 0)
|
||||
@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
||||
### 支付宝
|
||||
### 微信
|
||||
|
||||
另外,除此之外不要任何解释性语句。
|
||||
请直接生成结果,不要任何解释性语句。
|
||||
`})
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)})
|
||||
var req = types.ApiRequest{
|
||||
Model: chatModel.Value,
|
||||
Stream: true,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
response, err := h.doRequest(req, chatModel, &apiKey)
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)})
|
||||
content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
// 循环读取 Chunk 消息
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
return fmt.Errorf("error with decode data: %v", line)
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" {
|
||||
break
|
||||
}
|
||||
|
||||
utils.SendMsg(client, types.ReplyMessage{
|
||||
Type: types.MsgTypeText,
|
||||
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
} // end for
|
||||
|
||||
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
|
||||
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
|
||||
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
resp.ERROR(c, "error with save power log, "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id", chatModel.KeyId)
|
||||
} else { // use the last unused key
|
||||
session = session.Where("type", "chat").
|
||||
Where("enabled", true).Order("last_used_at ASC")
|
||||
}
|
||||
|
||||
res := session.First(apiKey)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||
return client.Do(request)
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if item.Progress < 100 {
|
||||
// 从 leveldb 中获取图片预览数据
|
||||
var imageData string
|
||||
err = h.leveldb.Get(item.TaskId, &imageData)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + imageData
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.sunoService.PushTask(types.SunoTask{
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client := h.sunoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
|
||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||
@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.videoService.PushTask(types.VideoTask{
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
})
|
||||
|
||||
// update user's power
|
||||
@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
client := h.videoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 只有失败或者超时的任务才能删除
|
||||
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
|
||||
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
|
||||
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
|
||||
return
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("Receive a message:%+v", message)
|
||||
logger.Debugf("Receive a message:%+v", message)
|
||||
if message.Type == types.MsgTypePing {
|
||||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||
continue
|
||||
|
@ -34,19 +34,21 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,6 +69,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
s.clientIds[task.JobId] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
@ -74,7 +77,7 @@ func (s *Service) Run() {
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
||||
@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@ -67,7 +67,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
|
@ -33,7 +33,6 @@ type Service struct {
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
leveldb *store.LevelDB
|
||||
wsService *service.WebsocketService
|
||||
}
|
||||
|
||||
@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
}
|
||||
@ -62,7 +60,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
@ -72,7 +70,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
@ -126,9 +124,8 @@ type Txt2ImgResp struct {
|
||||
|
||||
// TaskProgressResp 任务进度响应实体
|
||||
type TaskProgressResp struct {
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
CurrentImage string `json:"current_image"`
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
|
@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[string]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
wsService: wsService,
|
||||
clientIds: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -96,7 +98,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@ -105,6 +107,7 @@ func (s *Service) Run() {
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
s.clientIds[r.Data] = task.ClientId
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
logger.Debugf("client id: %+v", s.wsService.Clients)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
logger.Debugf("%+v", client)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSuno, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,7 +87,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@ -93,6 +95,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
if task.ClientId != "" {
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
}
|
||||
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
@ -105,7 +111,7 @@ func (s *Service) Run() {
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("Receive notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChLuma, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
|
@ -45,18 +45,25 @@ type apiRes struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
return SendOpenAIMessage(db, messages, modelName, keyId)
|
||||
}
|
||||
|
||||
func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
|
||||
if keyId > 0 {
|
||||
session = session.Where("id", keyId)
|
||||
}
|
||||
err := session.First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err)
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
client := req.C()
|
||||
|
1
database/update-v4.1.5.sql
Normal file
1
database/update-v4.1.5.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE `chatgpt_power_logs` CHANGE `remark` `remark` VARCHAR(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '备注';
|
@ -208,7 +208,7 @@ import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue";
|
||||
import {httpGet, httpPost} from "@/utils/http";
|
||||
import {ElMessage, ElMessageBox} from "element-plus";
|
||||
import Clipboard from "clipboard";
|
||||
import {checkSession, getSystemInfo} from "@/store/cache";
|
||||
import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
|
||||
import {useSharedStore} from "@/store/sharedata";
|
||||
import TaskList from "@/components/TaskList.vue";
|
||||
import BackTop from "@/components/BackTop.vue";
|
||||
@ -240,6 +240,7 @@ const styles = [
|
||||
{name: "自然", value: "natural"}
|
||||
]
|
||||
const params = ref({
|
||||
client_id: getClientId(),
|
||||
quality: "standard",
|
||||
size: "1024x1024",
|
||||
style: "vivid",
|
||||
@ -268,14 +269,24 @@ onMounted(() => {
|
||||
}).catch(e => {
|
||||
ElMessage.error("获取系统配置失败:" + e.message)
|
||||
})
|
||||
|
||||
store.addMessageHandler("dall",(data) => {
|
||||
// 丢弃无关消息
|
||||
if (data.channel !== "dall" || data.clientId !== getClientId()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (data.body === "FINISH" || data.body === "FAIL") {
|
||||
page.value = 0
|
||||
isOver.value = false
|
||||
fetchFinishJobs()
|
||||
}
|
||||
nextTick(() => fetchRunningJobs())
|
||||
})
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
clipboard.value.destroy()
|
||||
if (socket.value !== null) {
|
||||
socket.value.close()
|
||||
socket.value = null
|
||||
}
|
||||
})
|
||||
|
||||
const initData = () => {
|
||||
@ -287,51 +298,10 @@ const initData = () => {
|
||||
page.value = 0
|
||||
fetchRunningJobs()
|
||||
fetchFinishJobs()
|
||||
connect()
|
||||
}).catch(() => {
|
||||
});
|
||||
}
|
||||
|
||||
const socket = ref(null)
|
||||
const heartbeatHandle = ref(null)
|
||||
const connect = () => {
|
||||
let host = process.env.VUE_APP_WS_HOST
|
||||
if (host === '') {
|
||||
if (location.protocol === 'https:') {
|
||||
host = 'wss://' + location.host;
|
||||
} else {
|
||||
host = 'ws://' + location.host;
|
||||
}
|
||||
}
|
||||
|
||||
const _socket = new WebSocket(host + `/api/dall/client?user_id=${userId.value}`);
|
||||
_socket.addEventListener('open', () => {
|
||||
socket.value = _socket;
|
||||
});
|
||||
|
||||
_socket.addEventListener('message', event => {
|
||||
if (event.data instanceof Blob) {
|
||||
const reader = new FileReader();
|
||||
reader.readAsText(event.data, "UTF-8")
|
||||
reader.onload = () => {
|
||||
const message = String(reader.result)
|
||||
if (message === "FINISH" || message === "FAIL") {
|
||||
page.value = 0
|
||||
isOver.value = false
|
||||
fetchFinishJobs(page.value)
|
||||
}
|
||||
nextTick(() => fetchRunningJobs())
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
_socket.addEventListener('close', () => {
|
||||
if (socket.value !== null) {
|
||||
connect()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const fetchRunningJobs = () => {
|
||||
if (!isLogin.value) {
|
||||
return
|
||||
@ -391,6 +361,7 @@ const generate = () => {
|
||||
httpPost("/api/dall/image", params.value).then(() => {
|
||||
ElMessage.success("任务执行成功!")
|
||||
power.value -= dallPower.value
|
||||
fetchRunningJobs()
|
||||
}).catch(e => {
|
||||
ElMessage.error("任务执行失败:" + e.message)
|
||||
})
|
||||
|
@ -55,25 +55,6 @@
|
||||
<el-container class="video-container" v-loading="loading" element-loading-background="rgba(100,100,100,0.3)">
|
||||
<h2 class="h-title">你的作品</h2>
|
||||
|
||||
<!-- <el-row :gutter="20" class="videos" v-if="!noData">-->
|
||||
<!-- <el-col :span="8" class="item" :key="item.id" v-for="item in videos">-->
|
||||
<!-- <div class="video-box" @mouseover="item.playing = true" @mouseout="item.playing = false">-->
|
||||
<!-- <img :src="item.cover" :alt="item.name" v-show="!item.playing"/>-->
|
||||
<!-- <video :src="item.url" preload="auto" :autoplay="true" loop="loop" muted="muted" v-show="item.playing">-->
|
||||
<!-- 您的浏览器不支持视频播放-->
|
||||
<!-- </video>-->
|
||||
<!-- </div>-->
|
||||
<!-- <div class="video-name">{{item.name}}</div>-->
|
||||
<!-- <div class="opts">-->
|
||||
<!-- <button class="btn" @click="download(item)" :disabled="item.downloading">-->
|
||||
<!-- <i class="iconfont icon-download" v-if="!item.downloading"></i>-->
|
||||
<!-- <el-image src="/images/loading.gif" fit="cover" v-else />-->
|
||||
<!-- <span>下载</span>-->
|
||||
<!-- </button>-->
|
||||
<!-- </div>-->
|
||||
<!-- </el-col>-->
|
||||
<!-- </el-row>-->
|
||||
|
||||
<div class="list-box" v-if="!noData">
|
||||
<div v-for="item in list" :key="item.id">
|
||||
<div class="item">
|
||||
@ -153,13 +134,14 @@
|
||||
import {onMounted, reactive, ref} from "vue";
|
||||
import {CircleCloseFilled} from "@element-plus/icons-vue";
|
||||
import {httpDownload, httpPost, httpGet} from "@/utils/http";
|
||||
import {checkSession} from "@/store/cache";
|
||||
import {checkSession, getClientId} from "@/store/cache";
|
||||
import {showMessageError, showMessageOK} from "@/utils/dialog";
|
||||
import { replaceImg } from "@/utils/libs"
|
||||
import {ElMessage, ElMessageBox} from "element-plus";
|
||||
import BlackSwitch from "@/components/ui/BlackSwitch.vue";
|
||||
import Generating from "@/components/ui/Generating.vue";
|
||||
import BlackDialog from "@/components/ui/BlackDialog.vue";
|
||||
import {useSharedStore} from "@/store/sharedata";
|
||||
|
||||
const showDialog = ref(false)
|
||||
const currentVideoUrl = ref('')
|
||||
@ -167,6 +149,7 @@ const row = ref(1)
|
||||
const images = ref([])
|
||||
|
||||
const formData = reactive({
|
||||
client_id: getClientId(),
|
||||
prompt: '',
|
||||
expand_prompt: false,
|
||||
loop: false,
|
||||
@ -174,49 +157,22 @@ const formData = reactive({
|
||||
end_frame_img: ''
|
||||
})
|
||||
|
||||
const socket = ref(null)
|
||||
const userId = ref(0)
|
||||
const connect = () => {
|
||||
let host = process.env.VUE_APP_WS_HOST
|
||||
if (host === '') {
|
||||
if (location.protocol === 'https:') {
|
||||
host = 'wss://' + location.host;
|
||||
} else {
|
||||
host = 'ws://' + location.host;
|
||||
}
|
||||
}
|
||||
|
||||
const _socket = new WebSocket(host + `/api/video/client?user_id=${userId.value}`);
|
||||
_socket.addEventListener('open', () => {
|
||||
socket.value = _socket;
|
||||
});
|
||||
|
||||
_socket.addEventListener('message', event => {
|
||||
if (event.data instanceof Blob) {
|
||||
const reader = new FileReader();
|
||||
reader.readAsText(event.data, "UTF-8")
|
||||
reader.onload = () => {
|
||||
const message = String(reader.result)
|
||||
if (message === "FINISH" || message === "FAIL") {
|
||||
fetchData()
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
_socket.addEventListener('close', () => {
|
||||
if (socket.value !== null) {
|
||||
connect()
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const store = useSharedStore()
|
||||
onMounted(()=>{
|
||||
checkSession().then(user => {
|
||||
userId.value = user.id
|
||||
connect()
|
||||
checkSession().then(() => {
|
||||
fetchData(1)
|
||||
})
|
||||
|
||||
store.addMessageHandler("luma",(data) => {
|
||||
// 丢弃无关消息
|
||||
if (data.channel !== "luma" || data.clientId !== getClientId()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (data.body === "FINISH" || data.body === "FAIL") {
|
||||
fetchData(1)
|
||||
}
|
||||
})
|
||||
fetchData(1)
|
||||
})
|
||||
|
||||
const download = (item) => {
|
||||
|
@ -45,7 +45,7 @@
|
||||
|
||||
<div class="param-line">
|
||||
<el-button color="#47fff1" :dark="false" round @click="generateAI" :loading="loading">
|
||||
智能生成思维导图
|
||||
生成思维导图
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
@ -79,10 +79,7 @@
|
||||
</el-button>
|
||||
</div>
|
||||
|
||||
<div class="markdown" v-if="loading">
|
||||
<div :style="{ height: rightBoxHeight + 'px', overflow:'auto',width:'80%' }" v-html="html"></div>
|
||||
</div>
|
||||
<div class="body" id="markmap" v-show="!loading">
|
||||
<div class="body" id="markmap">
|
||||
<svg ref="svgRef" :style="{ height: rightBoxHeight + 'px' }"/>
|
||||
<div id="toolbar"></div>
|
||||
</div>
|
||||
@ -94,11 +91,11 @@
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import {nextTick, onUnmounted, ref} from 'vue';
|
||||
import {nextTick, ref} from 'vue';
|
||||
import {Markmap} from 'markmap-view';
|
||||
import {Transformer} from 'markmap-lib';
|
||||
import {checkSession, getSystemInfo} from "@/store/cache";
|
||||
import {httpGet} from "@/utils/http";
|
||||
import {httpGet, httpPost} from "@/utils/http";
|
||||
import {ElMessage} from "element-plus";
|
||||
import {Download} from "@element-plus/icons-vue";
|
||||
import {Toolbar} from 'markmap-toolbar';
|
||||
@ -106,11 +103,9 @@ import {useSharedStore} from "@/store/sharedata";
|
||||
|
||||
const leftBoxHeight = ref(window.innerHeight - 105)
|
||||
const rightBoxHeight = ref(window.innerHeight - 115)
|
||||
const title = ref("")
|
||||
|
||||
const prompt = ref("")
|
||||
const text = ref("")
|
||||
const md = require('markdown-it')({breaks: true});
|
||||
const content = ref(text.value)
|
||||
const html = ref("")
|
||||
|
||||
@ -118,13 +113,12 @@ const isLogin = ref(false)
|
||||
const loginUser = ref({power: 0})
|
||||
const transformer = new Transformer();
|
||||
const store = useSharedStore();
|
||||
|
||||
const loading = ref(false)
|
||||
|
||||
const svgRef = ref(null)
|
||||
const markMap = ref(null)
|
||||
const models = ref([])
|
||||
const modelID = ref(0)
|
||||
const loading = ref(false)
|
||||
|
||||
getSystemInfo().then(res => {
|
||||
text.value = res.data['mark_map_text']
|
||||
@ -147,9 +141,7 @@ getSystemInfo().then(res => {
|
||||
const initData = () => {
|
||||
httpGet("/api/model/list").then(res => {
|
||||
for (let v of res.data) {
|
||||
if (v.value.indexOf("gpt-4-gizmo") === -1) {
|
||||
models.value.push(v)
|
||||
}
|
||||
models.value.push(v)
|
||||
}
|
||||
modelID.value = models.value[0].id
|
||||
}).catch(e => {
|
||||
@ -159,7 +151,6 @@ const initData = () => {
|
||||
checkSession().then(user => {
|
||||
loginUser.value = user
|
||||
isLogin.value = true
|
||||
connect(user.id)
|
||||
}).catch(() => {
|
||||
});
|
||||
}
|
||||
@ -191,74 +182,11 @@ const processContent = (text) => {
|
||||
return arr.join("\n")
|
||||
}
|
||||
|
||||
onUnmounted(() => {
|
||||
if (socket.value !== null) {
|
||||
socket.value.close()
|
||||
}
|
||||
socket.value = null
|
||||
})
|
||||
|
||||
window.onresize = () => {
|
||||
leftBoxHeight.value = window.innerHeight - 145
|
||||
rightBoxHeight.value = window.innerHeight - 85
|
||||
}
|
||||
|
||||
const socket = ref(null)
|
||||
const connect = (userId) => {
|
||||
if (socket.value !== null) {
|
||||
socket.value.close()
|
||||
}
|
||||
|
||||
let host = process.env.VUE_APP_WS_HOST
|
||||
if (host === '') {
|
||||
if (location.protocol === 'https:') {
|
||||
host = 'wss://' + location.host;
|
||||
} else {
|
||||
host = 'ws://' + location.host;
|
||||
}
|
||||
}
|
||||
|
||||
const _socket = new WebSocket(host + `/api/markMap/client?user_id=${userId}&model_id=${modelID.value}`);
|
||||
_socket.addEventListener('open', () => {
|
||||
socket.value = _socket;
|
||||
});
|
||||
|
||||
_socket.addEventListener('message', event => {
|
||||
if (event.data instanceof Blob) {
|
||||
const reader = new FileReader();
|
||||
reader.readAsText(event.data, "UTF-8")
|
||||
const model = getModelById(modelID.value)
|
||||
reader.onload = () => {
|
||||
const data = JSON.parse(String(reader.result))
|
||||
switch (data.type) {
|
||||
case "content":
|
||||
text.value += data.content
|
||||
html.value = md.render(processContent(text.value))
|
||||
break
|
||||
case "end":
|
||||
loading.value = false
|
||||
content.value = processContent(text.value)
|
||||
loginUser.value.power -= model.power
|
||||
nextTick(() => update())
|
||||
break
|
||||
case "error":
|
||||
loading.value = false
|
||||
ElMessage.error(data.content)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
_socket.addEventListener('close', () => {
|
||||
loading.value = false
|
||||
checkSession().then(() => {
|
||||
connect(userId)
|
||||
}).catch(() => {
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
const generate = () => {
|
||||
text.value = content.value
|
||||
update()
|
||||
@ -276,19 +204,26 @@ const generateAI = () => {
|
||||
return
|
||||
}
|
||||
loading.value = true
|
||||
socket.value.send(JSON.stringify({type: "message", content: prompt.value}))
|
||||
}
|
||||
|
||||
const changeModel = () => {
|
||||
if (socket.value !== null) {
|
||||
socket.value.send(JSON.stringify({type: "model_id", content: modelID.value}))
|
||||
}
|
||||
httpPost("/api/markMap/gen", {
|
||||
prompt:prompt.value,
|
||||
model_id: modelID.value
|
||||
}).then(res => {
|
||||
text.value = res.data
|
||||
content.value = processContent(text.value)
|
||||
const model = getModelById(modelID.value)
|
||||
loginUser.value.power -= model.power
|
||||
nextTick(() => update())
|
||||
loading.value = false
|
||||
}).catch(e => {
|
||||
ElMessage.error("生成思维导图失败:" + e.message)
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
|
||||
const getModelById = (modelId) => {
|
||||
for (let e of models.value) {
|
||||
if (e.id === modelId) {
|
||||
return e
|
||||
for (let m of models.value) {
|
||||
if (m.id === modelId) {
|
||||
return m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -300,13 +300,14 @@ import MusicPlayer from "@/components/MusicPlayer.vue";
|
||||
import {compact} from "lodash";
|
||||
import {httpDownload, httpGet, httpPost} from "@/utils/http";
|
||||
import {showMessageError, showMessageOK} from "@/utils/dialog";
|
||||
import {checkSession} from "@/store/cache";
|
||||
import {checkSession, getClientId} from "@/store/cache";
|
||||
import {ElMessage, ElMessageBox} from "element-plus";
|
||||
import {formatTime, replaceImg} from "@/utils/libs";
|
||||
import Clipboard from "clipboard";
|
||||
import BlackDialog from "@/components/ui/BlackDialog.vue";
|
||||
import Compressor from "compressorjs";
|
||||
import Generating from "@/components/ui/Generating.vue";
|
||||
import {useSharedStore} from "@/store/sharedata";
|
||||
|
||||
const winHeight = ref(window.innerHeight - 50)
|
||||
const custom = ref(false)
|
||||
@ -333,6 +334,7 @@ const tags = ref([
|
||||
{label: "嘻哈", value: "hip hop"},
|
||||
])
|
||||
const data = ref({
|
||||
client_id: getClientId(),
|
||||
model: "chirp-v3-0",
|
||||
tags: "",
|
||||
lyrics: "",
|
||||
@ -354,45 +356,7 @@ const refSong = ref(null)
|
||||
const showDialog = ref(false)
|
||||
const editData = ref({title:"",cover:"",id:0})
|
||||
const promptPlaceholder = ref('请在这里输入你自己写的歌词...')
|
||||
|
||||
const socket = ref(null)
|
||||
const userId = ref(0)
|
||||
const connect = () => {
|
||||
let host = process.env.VUE_APP_WS_HOST
|
||||
if (host === '') {
|
||||
if (location.protocol === 'https:') {
|
||||
host = 'wss://' + location.host;
|
||||
} else {
|
||||
host = 'ws://' + location.host;
|
||||
}
|
||||
}
|
||||
|
||||
const _socket = new WebSocket(host + `/api/suno/client?user_id=${userId.value}`);
|
||||
_socket.addEventListener('open', () => {
|
||||
socket.value = _socket;
|
||||
});
|
||||
|
||||
_socket.addEventListener('message', event => {
|
||||
if (event.data instanceof Blob) {
|
||||
const reader = new FileReader();
|
||||
reader.readAsText(event.data, "UTF-8")
|
||||
reader.onload = () => {
|
||||
const message = String(reader.result)
|
||||
console.log(message)
|
||||
if (message === "FINISH" || message === "FAIL") {
|
||||
fetchData()
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
_socket.addEventListener('close', () => {
|
||||
if (socket.value !== null) {
|
||||
connect()
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const store = useSharedStore()
|
||||
const clipboard = ref(null)
|
||||
onMounted(() => {
|
||||
clipboard.value = new Clipboard('.copy-link');
|
||||
@ -405,10 +369,19 @@ onMounted(() => {
|
||||
})
|
||||
|
||||
checkSession().then(user => {
|
||||
userId.value = user.id
|
||||
connect()
|
||||
fetchData(1)
|
||||
})
|
||||
|
||||
store.addMessageHandler("suno",(data) => {
|
||||
// 丢弃无关消息
|
||||
if (data.channel !== "suno" || data.clientId !== getClientId()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (data.body === "FINISH" || data.body === "FAIL") {
|
||||
fetchData(1)
|
||||
}
|
||||
})
|
||||
fetchData(1)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
|
Loading…
Reference in New Issue
Block a user