mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	websocket api refactor is ready
This commit is contained in:
		@@ -1,4 +1,8 @@
 | 
				
			|||||||
# 更新日志
 | 
					# 更新日志
 | 
				
			||||||
 | 
					## v4.1.5
 | 
				
			||||||
 | 
					* 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接
 | 
				
			||||||
 | 
					* Bug修复:兼容手机端原生微信支付和支付宝支付渠道
 | 
				
			||||||
 | 
					* Bug修复:修复删除绘图任务时候因为字段长度过短导致SQL执行失败问题
 | 
				
			||||||
## v4.1.4
 | 
					## v4.1.4
 | 
				
			||||||
* 功能优化:用户文件列表组件增加分页功能支持
 | 
					* 功能优化:用户文件列表组件增加分页功能支持
 | 
				
			||||||
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
 | 
					* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -72,18 +72,20 @@ type SdTaskParams struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// DallTask DALL-E task
 | 
					// DallTask DALL-E task
 | 
				
			||||||
type DallTask struct {
 | 
					type DallTask struct {
 | 
				
			||||||
	JobId   uint   `json:"job_id"`
 | 
						ClientId string `json:"client_id"`
 | 
				
			||||||
	UserId  uint   `json:"user_id"`
 | 
						JobId    uint   `json:"job_id"`
 | 
				
			||||||
	Prompt  string `json:"prompt"`
 | 
						UserId   uint   `json:"user_id"`
 | 
				
			||||||
	N       int    `json:"n"`
 | 
						Prompt   string `json:"prompt"`
 | 
				
			||||||
	Quality string `json:"quality"`
 | 
						N        int    `json:"n"`
 | 
				
			||||||
	Size    string `json:"size"`
 | 
						Quality  string `json:"quality"`
 | 
				
			||||||
	Style   string `json:"style"`
 | 
						Size     string `json:"size"`
 | 
				
			||||||
 | 
						Style    string `json:"style"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Power int `json:"power"`
 | 
						Power int `json:"power"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type SunoTask struct {
 | 
					type SunoTask struct {
 | 
				
			||||||
 | 
						ClientId     string `json:"client_id"`
 | 
				
			||||||
	Id           uint   `json:"id"`
 | 
						Id           uint   `json:"id"`
 | 
				
			||||||
	Channel      string `json:"channel"`
 | 
						Channel      string `json:"channel"`
 | 
				
			||||||
	UserId       int    `json:"user_id"`
 | 
						UserId       int    `json:"user_id"`
 | 
				
			||||||
@@ -107,13 +109,14 @@ const (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type VideoTask struct {
 | 
					type VideoTask struct {
 | 
				
			||||||
	Id      uint        `json:"id"`
 | 
						ClientId string      `json:"client_id"`
 | 
				
			||||||
	Channel string      `json:"channel"`
 | 
						Id       uint        `json:"id"`
 | 
				
			||||||
	UserId  int         `json:"user_id"`
 | 
						Channel  string      `json:"channel"`
 | 
				
			||||||
	Type    string      `json:"type"`
 | 
						UserId   int         `json:"user_id"`
 | 
				
			||||||
	TaskId  string      `json:"task_id"`
 | 
						Type     string      `json:"type"`
 | 
				
			||||||
	Prompt  string      `json:"prompt"` // 提示词
 | 
						TaskId   string      `json:"task_id"`
 | 
				
			||||||
	Params  VideoParams `json:"params"`
 | 
						Prompt   string      `json:"prompt"` // 提示词
 | 
				
			||||||
 | 
						Params   VideoParams `json:"params"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type VideoParams struct {
 | 
					type VideoParams struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h.dallService.PushTask(types.DallTask{
 | 
						h.dallService.PushTask(types.DallTask{
 | 
				
			||||||
		JobId:   job.Id,
 | 
							ClientId: data.ClientId,
 | 
				
			||||||
		UserId:  uint(userId),
 | 
							JobId:    job.Id,
 | 
				
			||||||
		Prompt:  data.Prompt,
 | 
							UserId:   uint(userId),
 | 
				
			||||||
		Quality: data.Quality,
 | 
							Prompt:   data.Prompt,
 | 
				
			||||||
		Size:    data.Size,
 | 
							Quality:  data.Quality,
 | 
				
			||||||
		Style:   data.Style,
 | 
							Size:     data.Size,
 | 
				
			||||||
		Power:   job.Power,
 | 
							Style:    data.Style,
 | 
				
			||||||
 | 
							Power:    job.Power,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := h.dallService.Clients.Get(job.UserId)
 | 
					 | 
				
			||||||
	if client != nil {
 | 
					 | 
				
			||||||
		_ = client.Send([]byte("Task Updated"))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,23 +8,15 @@ package handler
 | 
				
			|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bufio"
 | 
					 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"geekai/core"
 | 
						"geekai/core"
 | 
				
			||||||
	"geekai/core/types"
 | 
						"geekai/core/types"
 | 
				
			||||||
	"geekai/service"
 | 
						"geekai/service"
 | 
				
			||||||
	"geekai/store/model"
 | 
						"geekai/store/model"
 | 
				
			||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"net/url"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// MarkMapHandler 生成思维导图
 | 
					// MarkMapHandler 生成思维导图
 | 
				
			||||||
@@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Generate 生成思维导图
 | 
					// Generate 生成思维导图
 | 
				
			||||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
 | 
					func (h *MarkMapHandler) Generate(c *gin.Context) {
 | 
				
			||||||
 | 
						var data struct {
 | 
				
			||||||
 | 
							Prompt  string `json:"prompt"`
 | 
				
			||||||
 | 
							ModelId int    `json:"model_id"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
						if err := c.ShouldBindJSON(&data); err != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, types.InvalidArgs)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
 | 
						userId := h.GetLoginUserId(c)
 | 
				
			||||||
	var user model.User
 | 
						var user model.User
 | 
				
			||||||
	res := h.DB.Model(&model.User{}).First(&user, userId)
 | 
						err := h.DB.Where("id", userId).First(&user, userId).Error
 | 
				
			||||||
	if res.Error != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("error with query user info: %v", res.Error)
 | 
							resp.ERROR(c, "error with query user info")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var chatModel model.ChatModel
 | 
						var chatModel model.ChatModel
 | 
				
			||||||
	res = h.DB.Where("id", modelId).First(&chatModel)
 | 
						err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
 | 
				
			||||||
	if res.Error != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("error with query chat model: %v", res.Error)
 | 
							resp.ERROR(c, "error with query chat model")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if user.Power < chatModel.Power {
 | 
						if user.Power < chatModel.Power {
 | 
				
			||||||
		return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
 | 
							resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	messages := make([]interface{}, 0)
 | 
						messages := make([]interface{}, 0)
 | 
				
			||||||
@@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
				
			|||||||
### 支付宝
 | 
					### 支付宝
 | 
				
			||||||
### 微信
 | 
					### 微信
 | 
				
			||||||
 | 
					
 | 
				
			||||||
另外,除此之外不要任何解释性语句。
 | 
					请直接生成结果,不要任何解释性语句。
 | 
				
			||||||
`})
 | 
					`})
 | 
				
			||||||
	messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)})
 | 
						messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)})
 | 
				
			||||||
	var req = types.ApiRequest{
 | 
						content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
 | 
				
			||||||
		Model:    chatModel.Value,
 | 
					 | 
				
			||||||
		Stream:   true,
 | 
					 | 
				
			||||||
		Messages: messages,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var apiKey model.ApiKey
 | 
					 | 
				
			||||||
	response, err := h.doRequest(req, chatModel, &apiKey)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("请求 OpenAI API 失败: %s", err)
 | 
							resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
 | 
				
			||||||
	}
 | 
							return
 | 
				
			||||||
 | 
					 | 
				
			||||||
	defer response.Body.Close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	contentType := response.Header.Get("Content-Type")
 | 
					 | 
				
			||||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
					 | 
				
			||||||
		// 循环读取 Chunk 消息
 | 
					 | 
				
			||||||
		scanner := bufio.NewScanner(response.Body)
 | 
					 | 
				
			||||||
		for scanner.Scan() {
 | 
					 | 
				
			||||||
			line := scanner.Text()
 | 
					 | 
				
			||||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			var responseBody = types.ApiResponse{}
 | 
					 | 
				
			||||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
					 | 
				
			||||||
			if err != nil { // 数据解析出错
 | 
					 | 
				
			||||||
				return fmt.Errorf("error with decode data: %v", line)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if responseBody.Choices[0].FinishReason == "stop" {
 | 
					 | 
				
			||||||
				break
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			utils.SendMsg(client, types.ReplyMessage{
 | 
					 | 
				
			||||||
				Type: types.MsgTypeText,
 | 
					 | 
				
			||||||
				Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
		} // end for
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		body, _ := io.ReadAll(response.Body)
 | 
					 | 
				
			||||||
		return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 扣减算力
 | 
						// 扣减算力
 | 
				
			||||||
	if chatModel.Power > 0 {
 | 
						if chatModel.Power > 0 {
 | 
				
			||||||
		err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
 | 
							err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
 | 
				
			||||||
			Type:   types.PowerConsume,
 | 
								Type:   types.PowerConsume,
 | 
				
			||||||
			Model:  chatModel.Value,
 | 
								Model:  chatModel.Value,
 | 
				
			||||||
			Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
 | 
								Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								resp.ERROR(c, "error with save power log, "+err.Error())
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						resp.SUCCESS(c, content)
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	session := h.DB.Session(&gorm.Session{})
 | 
					 | 
				
			||||||
	// if the chat model bind a KEY, use it directly
 | 
					 | 
				
			||||||
	if chatModel.KeyId > 0 {
 | 
					 | 
				
			||||||
		session = session.Where("id", chatModel.KeyId)
 | 
					 | 
				
			||||||
	} else { // use the last unused key
 | 
					 | 
				
			||||||
		session = session.Where("type", "chat").
 | 
					 | 
				
			||||||
			Where("enabled", true).Order("last_used_at ASC")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	res := session.First(apiKey)
 | 
					 | 
				
			||||||
	if res.Error != nil {
 | 
					 | 
				
			||||||
		return nil, errors.New("no available key, please import key")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 | 
					 | 
				
			||||||
	// 更新 API KEY 的最后使用时间
 | 
					 | 
				
			||||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// 创建 HttpClient 请求对象
 | 
					 | 
				
			||||||
	var client *http.Client
 | 
					 | 
				
			||||||
	requestBody, err := json.Marshal(req)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	request.Header.Set("Content-Type", "application/json")
 | 
					 | 
				
			||||||
	if len(apiKey.ProxyURL) > 5 { // 使用代理
 | 
					 | 
				
			||||||
		proxy, _ := url.Parse(apiKey.ProxyURL)
 | 
					 | 
				
			||||||
		client = &http.Client{
 | 
					 | 
				
			||||||
			Transport: &http.Transport{
 | 
					 | 
				
			||||||
				Proxy: http.ProxyURL(proxy),
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		client = http.DefaultClient
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
					 | 
				
			||||||
	logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
 | 
					 | 
				
			||||||
	return client.Do(request)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		if item.Progress < 100 {
 | 
					 | 
				
			||||||
			// 从 leveldb 中获取图片预览数据
 | 
					 | 
				
			||||||
			var imageData string
 | 
					 | 
				
			||||||
			err = h.leveldb.Get(item.TaskId, &imageData)
 | 
					 | 
				
			||||||
			if err == nil {
 | 
					 | 
				
			||||||
				job.ImgURL = "data:image/png;base64," + imageData
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		jobs = append(jobs, job)
 | 
							jobs = append(jobs, job)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
 | 
				
			|||||||
func (h *SunoHandler) Create(c *gin.Context) {
 | 
					func (h *SunoHandler) Create(c *gin.Context) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
 | 
							ClientId     string `json:"client_id"`
 | 
				
			||||||
		Prompt       string `json:"prompt"`
 | 
							Prompt       string `json:"prompt"`
 | 
				
			||||||
		Instrumental bool   `json:"instrumental"`
 | 
							Instrumental bool   `json:"instrumental"`
 | 
				
			||||||
		Lyrics       string `json:"lyrics"`
 | 
							Lyrics       string `json:"lyrics"`
 | 
				
			||||||
@@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 创建任务
 | 
						// 创建任务
 | 
				
			||||||
	h.sunoService.PushTask(types.SunoTask{
 | 
						h.sunoService.PushTask(types.SunoTask{
 | 
				
			||||||
 | 
							ClientId:     data.ClientId,
 | 
				
			||||||
		Id:           job.Id,
 | 
							Id:           job.Id,
 | 
				
			||||||
		UserId:       job.UserId,
 | 
							UserId:       job.UserId,
 | 
				
			||||||
		Type:         job.Type,
 | 
							Type:         job.Type,
 | 
				
			||||||
@@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	client := h.sunoService.Clients.Get(uint(job.UserId))
 | 
					 | 
				
			||||||
	if client != nil {
 | 
					 | 
				
			||||||
		_ = client.Send([]byte("Task Updated"))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
 | 
				
			|||||||
		resp.ERROR(c, types.InvalidArgs)
 | 
							resp.ERROR(c, types.InvalidArgs)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
 | 
						content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		resp.ERROR(c, err.Error())
 | 
							resp.ERROR(c, err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
 | 
				
			|||||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
					func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
 | 
							ClientId      string `json:"client_id"`
 | 
				
			||||||
		Prompt        string `json:"prompt"`
 | 
							Prompt        string `json:"prompt"`
 | 
				
			||||||
		FirstFrameImg string `json:"first_frame_img,omitempty"`
 | 
							FirstFrameImg string `json:"first_frame_img,omitempty"`
 | 
				
			||||||
		EndFrameImg   string `json:"end_frame_img,omitempty"`
 | 
							EndFrameImg   string `json:"end_frame_img,omitempty"`
 | 
				
			||||||
@@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 创建任务
 | 
						// 创建任务
 | 
				
			||||||
	h.videoService.PushTask(types.VideoTask{
 | 
						h.videoService.PushTask(types.VideoTask{
 | 
				
			||||||
		Id:     job.Id,
 | 
							ClientId: data.ClientId,
 | 
				
			||||||
		UserId: userId,
 | 
							Id:       job.Id,
 | 
				
			||||||
		Type:   types.VideoLuma,
 | 
							UserId:   userId,
 | 
				
			||||||
		Prompt: data.Prompt,
 | 
							Type:     types.VideoLuma,
 | 
				
			||||||
		Params: params,
 | 
							Prompt:   data.Prompt,
 | 
				
			||||||
 | 
							Params:   params,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// update user's power
 | 
						// update user's power
 | 
				
			||||||
@@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
				
			|||||||
		resp.ERROR(c, err.Error())
 | 
							resp.ERROR(c, err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := h.videoService.Clients.Get(uint(job.UserId))
 | 
					 | 
				
			||||||
	if client != nil {
 | 
					 | 
				
			||||||
		_ = client.Send([]byte("Task Updated"))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// 只有失败或者超时的任务才能删除
 | 
						// 只有失败或者超时的任务才能删除
 | 
				
			||||||
	if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
 | 
						if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
 | 
				
			||||||
		resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
 | 
							resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
 | 
				
			|||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			logger.Infof("Receive a message:%+v", message)
 | 
								logger.Debugf("Receive a message:%+v", message)
 | 
				
			||||||
			if message.Type == types.MsgTypePing {
 | 
								if message.Type == types.MsgTypePing {
 | 
				
			||||||
				utils.SendChannelMsg(client, types.ChPing, "pong")
 | 
									utils.SendChannelMsg(client, types.ChPing, "pong")
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,19 +34,21 @@ type Service struct {
 | 
				
			|||||||
	uploadManager *oss.UploaderManager
 | 
						uploadManager *oss.UploaderManager
 | 
				
			||||||
	taskQueue     *store.RedisQueue
 | 
						taskQueue     *store.RedisQueue
 | 
				
			||||||
	notifyQueue   *store.RedisQueue
 | 
						notifyQueue   *store.RedisQueue
 | 
				
			||||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
					 | 
				
			||||||
	userService   *service.UserService
 | 
						userService   *service.UserService
 | 
				
			||||||
 | 
						wsService     *service.WebsocketService
 | 
				
			||||||
 | 
						clientIds     map[uint]string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
 | 
					func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
 | 
				
			||||||
	return &Service{
 | 
						return &Service{
 | 
				
			||||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
							httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
				
			||||||
		db:            db,
 | 
							db:            db,
 | 
				
			||||||
		taskQueue:     store.NewRedisQueue("DallE_Task_Queue", redisCli),
 | 
							taskQueue:     store.NewRedisQueue("DallE_Task_Queue", redisCli),
 | 
				
			||||||
		notifyQueue:   store.NewRedisQueue("DallE_Notify_Queue", redisCli),
 | 
							notifyQueue:   store.NewRedisQueue("DallE_Notify_Queue", redisCli),
 | 
				
			||||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
							wsService:     wsService,
 | 
				
			||||||
		uploadManager: manager,
 | 
							uploadManager: manager,
 | 
				
			||||||
		userService:   userService,
 | 
							userService:   userService,
 | 
				
			||||||
 | 
							clientIds:     map[uint]string{},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -67,6 +69,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			logger.Infof("handle a new DALL-E task: %+v", task)
 | 
								logger.Infof("handle a new DALL-E task: %+v", task)
 | 
				
			||||||
 | 
								s.clientIds[task.JobId] = task.ClientId
 | 
				
			||||||
			_, err = s.Image(task, false)
 | 
								_, err = s.Image(task, false)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				logger.Errorf("error with image task: %v", err)
 | 
									logger.Errorf("error with image task: %v", err)
 | 
				
			||||||
@@ -74,7 +77,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
					"progress": service.FailTaskProgress,
 | 
										"progress": service.FailTaskProgress,
 | 
				
			||||||
					"err_msg":  err.Error(),
 | 
										"err_msg":  err.Error(),
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
				
			|||||||
	prompt := task.Prompt
 | 
						prompt := task.Prompt
 | 
				
			||||||
	// translate prompt
 | 
						// translate prompt
 | 
				
			||||||
	if utils.HasChinese(prompt) {
 | 
						if utils.HasChinese(prompt) {
 | 
				
			||||||
		content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
 | 
							content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
 | 
				
			||||||
		if err == nil {
 | 
							if err == nil {
 | 
				
			||||||
			prompt = content
 | 
								prompt = content
 | 
				
			||||||
			logger.Debugf("重写后提示词:%s", prompt)
 | 
								logger.Debugf("重写后提示词:%s", prompt)
 | 
				
			||||||
@@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
				
			|||||||
		return "", fmt.Errorf("err with update database: %v", err)
 | 
							return "", fmt.Errorf("err with update database: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
						s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
				
			||||||
	var content string
 | 
						var content string
 | 
				
			||||||
	if sync {
 | 
						if sync {
 | 
				
			||||||
		imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
 | 
							imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
 | 
				
			||||||
@@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			client := s.Clients.Get(uint(message.UserId))
 | 
					
 | 
				
			||||||
 | 
								logger.Debugf("notify message: %+v", message)
 | 
				
			||||||
 | 
								client := s.wsService.Clients.Get(message.ClientId)
 | 
				
			||||||
			if client == nil {
 | 
								if client == nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err = client.Send([]byte(message.Message))
 | 
								utils.SendChannelMsg(client, types.ChDall, message.Message)
 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
 | 
				
			|||||||
	if res.Error != nil {
 | 
						if res.Error != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
 | 
						s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
 | 
				
			||||||
	return imgURL, nil
 | 
						return imgURL, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,7 +58,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// translate prompt
 | 
								// translate prompt
 | 
				
			||||||
			if utils.HasChinese(task.Prompt) {
 | 
								if utils.HasChinese(task.Prompt) {
 | 
				
			||||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
 | 
									content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
 | 
				
			||||||
				if err == nil {
 | 
									if err == nil {
 | 
				
			||||||
					task.Prompt = content
 | 
										task.Prompt = content
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@@ -67,7 +67,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			// translate negative prompt
 | 
								// translate negative prompt
 | 
				
			||||||
			if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
 | 
								if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
 | 
				
			||||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
 | 
									content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
 | 
				
			||||||
				if err == nil {
 | 
									if err == nil {
 | 
				
			||||||
					task.NegPrompt = content
 | 
										task.NegPrompt = content
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								logger.Debugf("receive a new mj notify message: %+v", message)
 | 
				
			||||||
			client := s.wsService.Clients.Get(message.ClientId)
 | 
								client := s.wsService.Clients.Get(message.ClientId)
 | 
				
			||||||
			if client == nil {
 | 
								if client == nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -33,7 +33,6 @@ type Service struct {
 | 
				
			|||||||
	notifyQueue   *store.RedisQueue
 | 
						notifyQueue   *store.RedisQueue
 | 
				
			||||||
	db            *gorm.DB
 | 
						db            *gorm.DB
 | 
				
			||||||
	uploadManager *oss.UploaderManager
 | 
						uploadManager *oss.UploaderManager
 | 
				
			||||||
	leveldb       *store.LevelDB
 | 
					 | 
				
			||||||
	wsService     *service.WebsocketService
 | 
						wsService     *service.WebsocketService
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
 | 
				
			|||||||
		taskQueue:     store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
 | 
							taskQueue:     store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
 | 
				
			||||||
		notifyQueue:   store.NewRedisQueue("StableDiffusion_Queue", redisCli),
 | 
							notifyQueue:   store.NewRedisQueue("StableDiffusion_Queue", redisCli),
 | 
				
			||||||
		db:            db,
 | 
							db:            db,
 | 
				
			||||||
		leveldb:       levelDB,
 | 
					 | 
				
			||||||
		wsService:     wsService,
 | 
							wsService:     wsService,
 | 
				
			||||||
		uploadManager: manager,
 | 
							uploadManager: manager,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -62,7 +60,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// translate prompt
 | 
								// translate prompt
 | 
				
			||||||
			if utils.HasChinese(task.Params.Prompt) {
 | 
								if utils.HasChinese(task.Params.Prompt) {
 | 
				
			||||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
 | 
									content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
 | 
				
			||||||
				if err == nil {
 | 
									if err == nil {
 | 
				
			||||||
					task.Params.Prompt = content
 | 
										task.Params.Prompt = content
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@@ -72,7 +70,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// translate negative prompt
 | 
								// translate negative prompt
 | 
				
			||||||
			if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
								if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
				
			||||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
 | 
									content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
 | 
				
			||||||
				if err == nil {
 | 
									if err == nil {
 | 
				
			||||||
					task.Params.NegPrompt = content
 | 
										task.Params.NegPrompt = content
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// TaskProgressResp 任务进度响应实体
 | 
					// TaskProgressResp 任务进度响应实体
 | 
				
			||||||
type TaskProgressResp struct {
 | 
					type TaskProgressResp struct {
 | 
				
			||||||
	Progress     float64 `json:"progress"`
 | 
						Progress    float64 `json:"progress"`
 | 
				
			||||||
	EtaRelative  float64 `json:"eta_relative"`
 | 
						EtaRelative float64 `json:"eta_relative"`
 | 
				
			||||||
	CurrentImage string  `json:"current_image"`
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Txt2Img 文生图 API
 | 
					// Txt2Img 文生图 API
 | 
				
			||||||
@@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
				
			|||||||
			// task finished
 | 
								// task finished
 | 
				
			||||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
								s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
				
			||||||
			s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
 | 
								s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
 | 
				
			||||||
			// 从 leveldb 中删除预览图片数据
 | 
					 | 
				
			||||||
			_ = s.leveldb.Delete(task.Params.TaskId)
 | 
					 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			err, resp := s.checkTaskProgress(apiKey)
 | 
								err, resp := s.checkTaskProgress(apiKey)
 | 
				
			||||||
@@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
				
			|||||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
									s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
				
			||||||
				// 发送更新状态信号
 | 
									// 发送更新状态信号
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
 | 
				
			||||||
				// 保存预览图片数据
 | 
					 | 
				
			||||||
				if resp.CurrentImage != "" {
 | 
					 | 
				
			||||||
					_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			time.Sleep(time.Second)
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								logger.Debugf("notify message: %+v", message)
 | 
				
			||||||
			client := s.wsService.Clients.Get(message.ClientId)
 | 
								client := s.wsService.Clients.Get(message.ClientId)
 | 
				
			||||||
			if client == nil {
 | 
								if client == nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,17 +34,19 @@ type Service struct {
 | 
				
			|||||||
	uploadManager *oss.UploaderManager
 | 
						uploadManager *oss.UploaderManager
 | 
				
			||||||
	taskQueue     *store.RedisQueue
 | 
						taskQueue     *store.RedisQueue
 | 
				
			||||||
	notifyQueue   *store.RedisQueue
 | 
						notifyQueue   *store.RedisQueue
 | 
				
			||||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
						wsService     *service.WebsocketService
 | 
				
			||||||
 | 
						clientIds     map[string]string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
					func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
 | 
				
			||||||
	return &Service{
 | 
						return &Service{
 | 
				
			||||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
							httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
				
			||||||
		db:            db,
 | 
							db:            db,
 | 
				
			||||||
		taskQueue:     store.NewRedisQueue("Suno_Task_Queue", redisCli),
 | 
							taskQueue:     store.NewRedisQueue("Suno_Task_Queue", redisCli),
 | 
				
			||||||
		notifyQueue:   store.NewRedisQueue("Suno_Notify_Queue", redisCli),
 | 
							notifyQueue:   store.NewRedisQueue("Suno_Notify_Queue", redisCli),
 | 
				
			||||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
					 | 
				
			||||||
		uploadManager: manager,
 | 
							uploadManager: manager,
 | 
				
			||||||
 | 
							wsService:     wsService,
 | 
				
			||||||
 | 
							clientIds:     map[string]string{},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -96,7 +98,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
					"err_msg":  err.Error(),
 | 
										"err_msg":  err.Error(),
 | 
				
			||||||
					"progress": service.FailTaskProgress,
 | 
										"progress": service.FailTaskProgress,
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -105,6 +107,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
				"task_id": r.Data,
 | 
									"task_id": r.Data,
 | 
				
			||||||
				"channel": r.Channel,
 | 
									"channel": r.Channel,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 | 
								s.clientIds[r.Data] = task.ClientId
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			client := s.Clients.Get(uint(message.UserId))
 | 
								logger.Debugf("notify message: %+v", message)
 | 
				
			||||||
 | 
								logger.Debugf("client id: %+v", s.wsService.Clients)
 | 
				
			||||||
 | 
								client := s.wsService.Clients.Get(message.ClientId)
 | 
				
			||||||
 | 
								logger.Debugf("%+v", client)
 | 
				
			||||||
			if client == nil {
 | 
								if client == nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err = client.Send([]byte(message.Message))
 | 
								utils.SendChannelMsg(client, types.ChSuno, message.Message)
 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
 | 
				
			|||||||
				v.AudioURL = audioURL
 | 
									v.AudioURL = audioURL
 | 
				
			||||||
				v.Progress = 100
 | 
									v.Progress = 100
 | 
				
			||||||
				s.db.Updates(&v)
 | 
									s.db.Updates(&v)
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			time.Sleep(time.Second * 10)
 | 
								time.Sleep(time.Second * 10)
 | 
				
			||||||
@@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
 | 
				
			|||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					tx.Commit()
 | 
										tx.Commit()
 | 
				
			||||||
 | 
										s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
 | 
				
			||||||
				} else if task.Data.FailReason != "" {
 | 
									} else if task.Data.FailReason != "" {
 | 
				
			||||||
					job.Progress = service.FailTaskProgress
 | 
										job.Progress = service.FailTaskProgress
 | 
				
			||||||
					job.ErrMsg = task.Data.FailReason
 | 
										job.ErrMsg = task.Data.FailReason
 | 
				
			||||||
					s.db.Updates(&job)
 | 
										s.db.Updates(&job)
 | 
				
			||||||
					s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
										s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,17 +34,19 @@ type Service struct {
 | 
				
			|||||||
	uploadManager *oss.UploaderManager
 | 
						uploadManager *oss.UploaderManager
 | 
				
			||||||
	taskQueue     *store.RedisQueue
 | 
						taskQueue     *store.RedisQueue
 | 
				
			||||||
	notifyQueue   *store.RedisQueue
 | 
						notifyQueue   *store.RedisQueue
 | 
				
			||||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
						wsService     *service.WebsocketService
 | 
				
			||||||
 | 
						clientIds     map[uint]string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
					func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
 | 
				
			||||||
	return &Service{
 | 
						return &Service{
 | 
				
			||||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
							httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
				
			||||||
		db:            db,
 | 
							db:            db,
 | 
				
			||||||
		taskQueue:     store.NewRedisQueue("Video_Task_Queue", redisCli),
 | 
							taskQueue:     store.NewRedisQueue("Video_Task_Queue", redisCli),
 | 
				
			||||||
		notifyQueue:   store.NewRedisQueue("Video_Notify_Queue", redisCli),
 | 
							notifyQueue:   store.NewRedisQueue("Video_Notify_Queue", redisCli),
 | 
				
			||||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
							wsService:     wsService,
 | 
				
			||||||
		uploadManager: manager,
 | 
							uploadManager: manager,
 | 
				
			||||||
 | 
							clientIds:     map[uint]string{},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -85,7 +87,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// translate prompt
 | 
								// translate prompt
 | 
				
			||||||
			if utils.HasChinese(task.Prompt) {
 | 
								if utils.HasChinese(task.Prompt) {
 | 
				
			||||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
 | 
									content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
 | 
				
			||||||
				if err == nil {
 | 
									if err == nil {
 | 
				
			||||||
					task.Prompt = content
 | 
										task.Prompt = content
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
@@ -93,6 +95,10 @@ func (s *Service) Run() {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if task.ClientId != "" {
 | 
				
			||||||
 | 
									s.clientIds[task.Id] = task.ClientId
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var r LumaRespVo
 | 
								var r LumaRespVo
 | 
				
			||||||
			r, err = s.LumaCreate(task)
 | 
								r, err = s.LumaCreate(task)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
@@ -105,7 +111,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					logger.Errorf("update task with error: %v", err)
 | 
										logger.Errorf("update task with error: %v", err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			client := s.Clients.Get(uint(message.UserId))
 | 
								logger.Debugf("Receive notify message: %+v", message)
 | 
				
			||||||
 | 
								client := s.wsService.Clients.Get(message.ClientId)
 | 
				
			||||||
			if client == nil {
 | 
								if client == nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err = client.Send([]byte(message.Message))
 | 
								utils.SendChannelMsg(client, types.ChLuma, message.Message)
 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() {
 | 
				
			|||||||
				v.VideoURL = videoURL
 | 
									v.VideoURL = videoURL
 | 
				
			||||||
				v.Progress = 100
 | 
									v.Progress = 100
 | 
				
			||||||
				s.db.Updates(&v)
 | 
									s.db.Updates(&v)
 | 
				
			||||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
									s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			time.Sleep(time.Second * 10)
 | 
								time.Sleep(time.Second * 10)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,18 +45,25 @@ type apiRes struct {
 | 
				
			|||||||
	} `json:"choices"`
 | 
						} `json:"choices"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
 | 
					func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
 | 
				
			||||||
	var apiKey model.ApiKey
 | 
					 | 
				
			||||||
	res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
 | 
					 | 
				
			||||||
	if res.Error != nil {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	messages := make([]interface{}, 1)
 | 
						messages := make([]interface{}, 1)
 | 
				
			||||||
	messages[0] = types.Message{
 | 
						messages[0] = types.Message{
 | 
				
			||||||
		Role:    "user",
 | 
							Role:    "user",
 | 
				
			||||||
		Content: prompt,
 | 
							Content: prompt,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return SendOpenAIMessage(db, messages, modelName, keyId)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
 | 
				
			||||||
 | 
						var apiKey model.ApiKey
 | 
				
			||||||
 | 
						session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
 | 
				
			||||||
 | 
						if keyId > 0 {
 | 
				
			||||||
 | 
							session = session.Where("id", keyId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err := session.First(&apiKey).Error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var response apiRes
 | 
						var response apiRes
 | 
				
			||||||
	client := req.C()
 | 
						client := req.C()
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										1
									
								
								database/update-v4.1.5.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								database/update-v4.1.5.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					ALTER TABLE `chatgpt_power_logs` CHANGE `remark` `remark` VARCHAR(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '备注';
 | 
				
			||||||
@@ -208,7 +208,7 @@ import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue";
 | 
				
			|||||||
import {httpGet, httpPost} from "@/utils/http";
 | 
					import {httpGet, httpPost} from "@/utils/http";
 | 
				
			||||||
import {ElMessage, ElMessageBox} from "element-plus";
 | 
					import {ElMessage, ElMessageBox} from "element-plus";
 | 
				
			||||||
import Clipboard from "clipboard";
 | 
					import Clipboard from "clipboard";
 | 
				
			||||||
import {checkSession, getSystemInfo} from "@/store/cache";
 | 
					import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
 | 
				
			||||||
import {useSharedStore} from "@/store/sharedata";
 | 
					import {useSharedStore} from "@/store/sharedata";
 | 
				
			||||||
import TaskList from "@/components/TaskList.vue";
 | 
					import TaskList from "@/components/TaskList.vue";
 | 
				
			||||||
import BackTop from "@/components/BackTop.vue";
 | 
					import BackTop from "@/components/BackTop.vue";
 | 
				
			||||||
@@ -240,6 +240,7 @@ const styles = [
 | 
				
			|||||||
  {name: "自然", value: "natural"}
 | 
					  {name: "自然", value: "natural"}
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
const params = ref({
 | 
					const params = ref({
 | 
				
			||||||
 | 
					  client_id: getClientId(),
 | 
				
			||||||
  quality: "standard",
 | 
					  quality: "standard",
 | 
				
			||||||
  size: "1024x1024",
 | 
					  size: "1024x1024",
 | 
				
			||||||
  style: "vivid",
 | 
					  style: "vivid",
 | 
				
			||||||
@@ -268,14 +269,24 @@ onMounted(() => {
 | 
				
			|||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("获取系统配置失败:" + e.message)
 | 
					    ElMessage.error("获取系统配置失败:" + e.message)
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  store.addMessageHandler("dall",(data) => {
 | 
				
			||||||
 | 
					    // 丢弃无关消息
 | 
				
			||||||
 | 
					    if (data.channel !== "dall" || data.clientId !== getClientId()) {
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (data.body === "FINISH" || data.body === "FAIL") {
 | 
				
			||||||
 | 
					      page.value = 0
 | 
				
			||||||
 | 
					      isOver.value = false
 | 
				
			||||||
 | 
					      fetchFinishJobs()
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    nextTick(() => fetchRunningJobs())
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onUnmounted(() => {
 | 
					onUnmounted(() => {
 | 
				
			||||||
  clipboard.value.destroy()
 | 
					  clipboard.value.destroy()
 | 
				
			||||||
  if (socket.value !== null) {
 | 
					 | 
				
			||||||
    socket.value.close()
 | 
					 | 
				
			||||||
    socket.value = null
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const initData = () => {
 | 
					const initData = () => {
 | 
				
			||||||
@@ -287,51 +298,10 @@ const initData = () => {
 | 
				
			|||||||
    page.value = 0
 | 
					    page.value = 0
 | 
				
			||||||
    fetchRunningJobs()
 | 
					    fetchRunningJobs()
 | 
				
			||||||
    fetchFinishJobs()
 | 
					    fetchFinishJobs()
 | 
				
			||||||
    connect()
 | 
					 | 
				
			||||||
  }).catch(() => {
 | 
					  }).catch(() => {
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const socket = ref(null)
 | 
					 | 
				
			||||||
const heartbeatHandle = ref(null)
 | 
					 | 
				
			||||||
const connect = () => {
 | 
					 | 
				
			||||||
  let host = process.env.VUE_APP_WS_HOST
 | 
					 | 
				
			||||||
  if (host === '') {
 | 
					 | 
				
			||||||
    if (location.protocol === 'https:') {
 | 
					 | 
				
			||||||
      host = 'wss://' + location.host;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      host = 'ws://' + location.host;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const _socket = new WebSocket(host + `/api/dall/client?user_id=${userId.value}`);
 | 
					 | 
				
			||||||
  _socket.addEventListener('open', () => {
 | 
					 | 
				
			||||||
    socket.value = _socket;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('message', event => {
 | 
					 | 
				
			||||||
    if (event.data instanceof Blob) {
 | 
					 | 
				
			||||||
      const reader = new FileReader();
 | 
					 | 
				
			||||||
      reader.readAsText(event.data, "UTF-8")
 | 
					 | 
				
			||||||
      reader.onload = () => {
 | 
					 | 
				
			||||||
        const message = String(reader.result)
 | 
					 | 
				
			||||||
        if (message === "FINISH" || message === "FAIL") {
 | 
					 | 
				
			||||||
          page.value = 0
 | 
					 | 
				
			||||||
          isOver.value = false
 | 
					 | 
				
			||||||
          fetchFinishJobs(page.value)
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        nextTick(() => fetchRunningJobs())
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('close', () => {
 | 
					 | 
				
			||||||
    if (socket.value !== null) {
 | 
					 | 
				
			||||||
      connect()
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  })
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const fetchRunningJobs = () => {
 | 
					const fetchRunningJobs = () => {
 | 
				
			||||||
  if (!isLogin.value) {
 | 
					  if (!isLogin.value) {
 | 
				
			||||||
    return
 | 
					    return
 | 
				
			||||||
@@ -391,6 +361,7 @@ const generate = () => {
 | 
				
			|||||||
  httpPost("/api/dall/image", params.value).then(() => {
 | 
					  httpPost("/api/dall/image", params.value).then(() => {
 | 
				
			||||||
    ElMessage.success("任务执行成功!")
 | 
					    ElMessage.success("任务执行成功!")
 | 
				
			||||||
    power.value -= dallPower.value
 | 
					    power.value -= dallPower.value
 | 
				
			||||||
 | 
					    fetchRunningJobs()
 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("任务执行失败:" + e.message)
 | 
					    ElMessage.error("任务执行失败:" + e.message)
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -55,25 +55,6 @@
 | 
				
			|||||||
    <el-container class="video-container" v-loading="loading" element-loading-background="rgba(100,100,100,0.3)">
 | 
					    <el-container class="video-container" v-loading="loading" element-loading-background="rgba(100,100,100,0.3)">
 | 
				
			||||||
      <h2 class="h-title">你的作品</h2>
 | 
					      <h2 class="h-title">你的作品</h2>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<!--      <el-row :gutter="20" class="videos" v-if="!noData">-->
 | 
					 | 
				
			||||||
<!--        <el-col :span="8" class="item" :key="item.id" v-for="item in videos">-->
 | 
					 | 
				
			||||||
<!--          <div class="video-box" @mouseover="item.playing = true" @mouseout="item.playing = false">-->
 | 
					 | 
				
			||||||
<!--            <img :src="item.cover"  :alt="item.name" v-show="!item.playing"/>-->
 | 
					 | 
				
			||||||
<!--            <video :src="item.url"  preload="auto" :autoplay="true" loop="loop" muted="muted" v-show="item.playing">-->
 | 
					 | 
				
			||||||
<!--              您的浏览器不支持视频播放-->
 | 
					 | 
				
			||||||
<!--            </video>-->
 | 
					 | 
				
			||||||
<!--          </div>-->
 | 
					 | 
				
			||||||
<!--          <div class="video-name">{{item.name}}</div>-->
 | 
					 | 
				
			||||||
<!--          <div class="opts">-->
 | 
					 | 
				
			||||||
<!--            <button class="btn" @click="download(item)" :disabled="item.downloading">-->
 | 
					 | 
				
			||||||
<!--              <i class="iconfont icon-download" v-if="!item.downloading"></i>-->
 | 
					 | 
				
			||||||
<!--              <el-image src="/images/loading.gif" fit="cover" v-else />-->
 | 
					 | 
				
			||||||
<!--              <span>下载</span>-->
 | 
					 | 
				
			||||||
<!--            </button>-->
 | 
					 | 
				
			||||||
<!--          </div>-->
 | 
					 | 
				
			||||||
<!--        </el-col>-->
 | 
					 | 
				
			||||||
<!--      </el-row>-->
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      <div class="list-box" v-if="!noData">
 | 
					      <div class="list-box" v-if="!noData">
 | 
				
			||||||
        <div v-for="item in list" :key="item.id">
 | 
					        <div v-for="item in list" :key="item.id">
 | 
				
			||||||
          <div class="item">
 | 
					          <div class="item">
 | 
				
			||||||
@@ -153,13 +134,14 @@
 | 
				
			|||||||
import {onMounted, reactive, ref} from "vue";
 | 
					import {onMounted, reactive, ref} from "vue";
 | 
				
			||||||
import {CircleCloseFilled} from "@element-plus/icons-vue";
 | 
					import {CircleCloseFilled} from "@element-plus/icons-vue";
 | 
				
			||||||
import {httpDownload, httpPost, httpGet} from "@/utils/http";
 | 
					import {httpDownload, httpPost, httpGet} from "@/utils/http";
 | 
				
			||||||
import {checkSession} from "@/store/cache";
 | 
					import {checkSession, getClientId} from "@/store/cache";
 | 
				
			||||||
import {showMessageError, showMessageOK} from "@/utils/dialog";
 | 
					import {showMessageError, showMessageOK} from "@/utils/dialog";
 | 
				
			||||||
import { replaceImg } from "@/utils/libs"
 | 
					import { replaceImg } from "@/utils/libs"
 | 
				
			||||||
import {ElMessage, ElMessageBox} from "element-plus";
 | 
					import {ElMessage, ElMessageBox} from "element-plus";
 | 
				
			||||||
import BlackSwitch from "@/components/ui/BlackSwitch.vue";
 | 
					import BlackSwitch from "@/components/ui/BlackSwitch.vue";
 | 
				
			||||||
import Generating from "@/components/ui/Generating.vue";
 | 
					import Generating from "@/components/ui/Generating.vue";
 | 
				
			||||||
import BlackDialog from "@/components/ui/BlackDialog.vue";
 | 
					import BlackDialog from "@/components/ui/BlackDialog.vue";
 | 
				
			||||||
 | 
					import {useSharedStore} from "@/store/sharedata";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const showDialog = ref(false)
 | 
					const showDialog = ref(false)
 | 
				
			||||||
const currentVideoUrl = ref('')
 | 
					const currentVideoUrl = ref('')
 | 
				
			||||||
@@ -167,6 +149,7 @@ const row = ref(1)
 | 
				
			|||||||
const images = ref([])
 | 
					const images = ref([])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const formData = reactive({
 | 
					const formData = reactive({
 | 
				
			||||||
 | 
					  client_id: getClientId(),
 | 
				
			||||||
  prompt: '',
 | 
					  prompt: '',
 | 
				
			||||||
  expand_prompt: false,
 | 
					  expand_prompt: false,
 | 
				
			||||||
  loop: false,
 | 
					  loop: false,
 | 
				
			||||||
@@ -174,49 +157,22 @@ const formData = reactive({
 | 
				
			|||||||
  end_frame_img: ''
 | 
					  end_frame_img: ''
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const socket = ref(null)
 | 
					const store = useSharedStore()
 | 
				
			||||||
const userId = ref(0)
 | 
					 | 
				
			||||||
const connect = () => {
 | 
					 | 
				
			||||||
  let host = process.env.VUE_APP_WS_HOST
 | 
					 | 
				
			||||||
  if (host === '') {
 | 
					 | 
				
			||||||
    if (location.protocol === 'https:') {
 | 
					 | 
				
			||||||
      host = 'wss://' + location.host;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      host = 'ws://' + location.host;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const _socket = new WebSocket(host + `/api/video/client?user_id=${userId.value}`);
 | 
					 | 
				
			||||||
  _socket.addEventListener('open', () => {
 | 
					 | 
				
			||||||
    socket.value = _socket;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('message', event => {
 | 
					 | 
				
			||||||
    if (event.data instanceof Blob) {
 | 
					 | 
				
			||||||
      const reader = new FileReader();
 | 
					 | 
				
			||||||
      reader.readAsText(event.data, "UTF-8")
 | 
					 | 
				
			||||||
      reader.onload = () => {
 | 
					 | 
				
			||||||
        const message = String(reader.result)
 | 
					 | 
				
			||||||
        if (message === "FINISH" || message === "FAIL") {
 | 
					 | 
				
			||||||
          fetchData()
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('close', () => {
 | 
					 | 
				
			||||||
    if (socket.value !== null) {
 | 
					 | 
				
			||||||
      connect()
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
onMounted(()=>{
 | 
					onMounted(()=>{
 | 
				
			||||||
  checkSession().then(user => {
 | 
					  checkSession().then(() => {
 | 
				
			||||||
    userId.value = user.id
 | 
					    fetchData(1)
 | 
				
			||||||
    connect()
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  store.addMessageHandler("luma",(data) => {
 | 
				
			||||||
 | 
					    // 丢弃无关消息
 | 
				
			||||||
 | 
					    if (data.channel !== "luma" || data.clientId !== getClientId()) {
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (data.body === "FINISH" || data.body === "FAIL") {
 | 
				
			||||||
 | 
					      fetchData(1)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
  fetchData(1)
 | 
					 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const download = (item) => {
 | 
					const download = (item) => {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,7 +45,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
              <div class="param-line">
 | 
					              <div class="param-line">
 | 
				
			||||||
                <el-button color="#47fff1" :dark="false" round @click="generateAI" :loading="loading">
 | 
					                <el-button color="#47fff1" :dark="false" round @click="generateAI" :loading="loading">
 | 
				
			||||||
                  智能生成思维导图
 | 
					                  生成思维导图
 | 
				
			||||||
                </el-button>
 | 
					                </el-button>
 | 
				
			||||||
              </div>
 | 
					              </div>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -79,10 +79,7 @@
 | 
				
			|||||||
            </el-button>
 | 
					            </el-button>
 | 
				
			||||||
          </div>
 | 
					          </div>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
          <div class="markdown" v-if="loading">
 | 
					          <div class="body" id="markmap">
 | 
				
			||||||
            <div :style="{ height: rightBoxHeight + 'px', overflow:'auto',width:'80%' }" v-html="html"></div>
 | 
					 | 
				
			||||||
          </div>
 | 
					 | 
				
			||||||
          <div class="body" id="markmap" v-show="!loading">
 | 
					 | 
				
			||||||
            <svg ref="svgRef" :style="{ height: rightBoxHeight + 'px' }"/>
 | 
					            <svg ref="svgRef" :style="{ height: rightBoxHeight + 'px' }"/>
 | 
				
			||||||
            <div id="toolbar"></div>
 | 
					            <div id="toolbar"></div>
 | 
				
			||||||
          </div>
 | 
					          </div>
 | 
				
			||||||
@@ -94,11 +91,11 @@
 | 
				
			|||||||
</template>
 | 
					</template>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<script setup>
 | 
					<script setup>
 | 
				
			||||||
import {nextTick, onUnmounted, ref} from 'vue';
 | 
					import {nextTick, ref} from 'vue';
 | 
				
			||||||
import {Markmap} from 'markmap-view';
 | 
					import {Markmap} from 'markmap-view';
 | 
				
			||||||
import {Transformer} from 'markmap-lib';
 | 
					import {Transformer} from 'markmap-lib';
 | 
				
			||||||
import {checkSession, getSystemInfo} from "@/store/cache";
 | 
					import {checkSession, getSystemInfo} from "@/store/cache";
 | 
				
			||||||
import {httpGet} from "@/utils/http";
 | 
					import {httpGet, httpPost} from "@/utils/http";
 | 
				
			||||||
import {ElMessage} from "element-plus";
 | 
					import {ElMessage} from "element-plus";
 | 
				
			||||||
import {Download} from "@element-plus/icons-vue";
 | 
					import {Download} from "@element-plus/icons-vue";
 | 
				
			||||||
import {Toolbar} from 'markmap-toolbar';
 | 
					import {Toolbar} from 'markmap-toolbar';
 | 
				
			||||||
@@ -106,11 +103,9 @@ import {useSharedStore} from "@/store/sharedata";
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const leftBoxHeight = ref(window.innerHeight - 105)
 | 
					const leftBoxHeight = ref(window.innerHeight - 105)
 | 
				
			||||||
const rightBoxHeight = ref(window.innerHeight - 115)
 | 
					const rightBoxHeight = ref(window.innerHeight - 115)
 | 
				
			||||||
const title = ref("")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
const prompt = ref("")
 | 
					const prompt = ref("")
 | 
				
			||||||
const text = ref("")
 | 
					const text = ref("")
 | 
				
			||||||
const md = require('markdown-it')({breaks: true});
 | 
					 | 
				
			||||||
const content = ref(text.value)
 | 
					const content = ref(text.value)
 | 
				
			||||||
const html = ref("")
 | 
					const html = ref("")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -118,13 +113,12 @@ const isLogin = ref(false)
 | 
				
			|||||||
const loginUser = ref({power: 0})
 | 
					const loginUser = ref({power: 0})
 | 
				
			||||||
const transformer = new Transformer();
 | 
					const transformer = new Transformer();
 | 
				
			||||||
const store = useSharedStore();
 | 
					const store = useSharedStore();
 | 
				
			||||||
 | 
					const loading = ref(false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const svgRef = ref(null)
 | 
					const svgRef = ref(null)
 | 
				
			||||||
const markMap = ref(null)
 | 
					const markMap = ref(null)
 | 
				
			||||||
const models = ref([])
 | 
					const models = ref([])
 | 
				
			||||||
const modelID = ref(0)
 | 
					const modelID = ref(0)
 | 
				
			||||||
const loading = ref(false)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
getSystemInfo().then(res => {
 | 
					getSystemInfo().then(res => {
 | 
				
			||||||
  text.value = res.data['mark_map_text']
 | 
					  text.value = res.data['mark_map_text']
 | 
				
			||||||
@@ -147,9 +141,7 @@ getSystemInfo().then(res => {
 | 
				
			|||||||
const initData = () => {
 | 
					const initData = () => {
 | 
				
			||||||
  httpGet("/api/model/list").then(res => {
 | 
					  httpGet("/api/model/list").then(res => {
 | 
				
			||||||
    for (let v of res.data) {
 | 
					    for (let v of res.data) {
 | 
				
			||||||
      if (v.value.indexOf("gpt-4-gizmo") === -1) {
 | 
					      models.value.push(v)
 | 
				
			||||||
        models.value.push(v)
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    modelID.value = models.value[0].id
 | 
					    modelID.value = models.value[0].id
 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
@@ -159,7 +151,6 @@ const initData = () => {
 | 
				
			|||||||
  checkSession().then(user => {
 | 
					  checkSession().then(user => {
 | 
				
			||||||
    loginUser.value = user
 | 
					    loginUser.value = user
 | 
				
			||||||
    isLogin.value = true
 | 
					    isLogin.value = true
 | 
				
			||||||
    connect(user.id)
 | 
					 | 
				
			||||||
  }).catch(() => {
 | 
					  }).catch(() => {
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -191,74 +182,11 @@ const processContent = (text) => {
 | 
				
			|||||||
  return arr.join("\n")
 | 
					  return arr.join("\n")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onUnmounted(() => {
 | 
					 | 
				
			||||||
  if (socket.value !== null) {
 | 
					 | 
				
			||||||
    socket.value.close()
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  socket.value = null
 | 
					 | 
				
			||||||
})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
window.onresize = () => {
 | 
					window.onresize = () => {
 | 
				
			||||||
  leftBoxHeight.value = window.innerHeight - 145
 | 
					  leftBoxHeight.value = window.innerHeight - 145
 | 
				
			||||||
  rightBoxHeight.value = window.innerHeight - 85
 | 
					  rightBoxHeight.value = window.innerHeight - 85
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const socket = ref(null)
 | 
					 | 
				
			||||||
const connect = (userId) => {
 | 
					 | 
				
			||||||
  if (socket.value !== null) {
 | 
					 | 
				
			||||||
    socket.value.close()
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  let host = process.env.VUE_APP_WS_HOST
 | 
					 | 
				
			||||||
  if (host === '') {
 | 
					 | 
				
			||||||
    if (location.protocol === 'https:') {
 | 
					 | 
				
			||||||
      host = 'wss://' + location.host;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      host = 'ws://' + location.host;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const _socket = new WebSocket(host + `/api/markMap/client?user_id=${userId}&model_id=${modelID.value}`);
 | 
					 | 
				
			||||||
  _socket.addEventListener('open', () => {
 | 
					 | 
				
			||||||
    socket.value = _socket;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('message', event => {
 | 
					 | 
				
			||||||
    if (event.data instanceof Blob) {
 | 
					 | 
				
			||||||
      const reader = new FileReader();
 | 
					 | 
				
			||||||
      reader.readAsText(event.data, "UTF-8")
 | 
					 | 
				
			||||||
      const model = getModelById(modelID.value)
 | 
					 | 
				
			||||||
      reader.onload = () => {
 | 
					 | 
				
			||||||
        const data = JSON.parse(String(reader.result))
 | 
					 | 
				
			||||||
        switch (data.type) {
 | 
					 | 
				
			||||||
          case "content":
 | 
					 | 
				
			||||||
            text.value += data.content
 | 
					 | 
				
			||||||
            html.value = md.render(processContent(text.value))
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
          case "end":
 | 
					 | 
				
			||||||
            loading.value = false
 | 
					 | 
				
			||||||
            content.value = processContent(text.value)
 | 
					 | 
				
			||||||
            loginUser.value.power -= model.power
 | 
					 | 
				
			||||||
            nextTick(() => update())
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
          case "error":
 | 
					 | 
				
			||||||
            loading.value = false
 | 
					 | 
				
			||||||
            ElMessage.error(data.content)
 | 
					 | 
				
			||||||
            break
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('close', () => {
 | 
					 | 
				
			||||||
    loading.value = false
 | 
					 | 
				
			||||||
    checkSession().then(() => {
 | 
					 | 
				
			||||||
      connect(userId)
 | 
					 | 
				
			||||||
    }).catch(() => {
 | 
					 | 
				
			||||||
    })
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const generate = () => {
 | 
					const generate = () => {
 | 
				
			||||||
  text.value = content.value
 | 
					  text.value = content.value
 | 
				
			||||||
  update()
 | 
					  update()
 | 
				
			||||||
@@ -276,19 +204,26 @@ const generateAI = () => {
 | 
				
			|||||||
    return
 | 
					    return
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  loading.value = true
 | 
					  loading.value = true
 | 
				
			||||||
  socket.value.send(JSON.stringify({type: "message", content: prompt.value}))
 | 
					  httpPost("/api/markMap/gen", {
 | 
				
			||||||
}
 | 
					    prompt:prompt.value,
 | 
				
			||||||
 | 
					    model_id: modelID.value
 | 
				
			||||||
const changeModel = () => {
 | 
					  }).then(res => {
 | 
				
			||||||
  if (socket.value !== null) {
 | 
					    text.value = res.data
 | 
				
			||||||
    socket.value.send(JSON.stringify({type: "model_id", content: modelID.value}))
 | 
					    content.value = processContent(text.value)
 | 
				
			||||||
  }
 | 
					    const model = getModelById(modelID.value)
 | 
				
			||||||
 | 
					    loginUser.value.power -= model.power
 | 
				
			||||||
 | 
					    nextTick(() => update())
 | 
				
			||||||
 | 
					    loading.value = false
 | 
				
			||||||
 | 
					  }).catch(e => {
 | 
				
			||||||
 | 
					    ElMessage.error("生成思维导图失败:" + e.message)
 | 
				
			||||||
 | 
					    loading.value = false
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const getModelById = (modelId) => {
 | 
					const getModelById = (modelId) => {
 | 
				
			||||||
  for (let e of models.value) {
 | 
					  for (let m of models.value) {
 | 
				
			||||||
    if (e.id === modelId) {
 | 
					    if (m.id === modelId) {
 | 
				
			||||||
      return e
 | 
					      return m
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -300,13 +300,14 @@ import MusicPlayer from "@/components/MusicPlayer.vue";
 | 
				
			|||||||
import {compact} from "lodash";
 | 
					import {compact} from "lodash";
 | 
				
			||||||
import {httpDownload, httpGet, httpPost} from "@/utils/http";
 | 
					import {httpDownload, httpGet, httpPost} from "@/utils/http";
 | 
				
			||||||
import {showMessageError, showMessageOK} from "@/utils/dialog";
 | 
					import {showMessageError, showMessageOK} from "@/utils/dialog";
 | 
				
			||||||
import {checkSession} from "@/store/cache";
 | 
					import {checkSession, getClientId} from "@/store/cache";
 | 
				
			||||||
import {ElMessage, ElMessageBox} from "element-plus";
 | 
					import {ElMessage, ElMessageBox} from "element-plus";
 | 
				
			||||||
import {formatTime, replaceImg} from "@/utils/libs";
 | 
					import {formatTime, replaceImg} from "@/utils/libs";
 | 
				
			||||||
import Clipboard from "clipboard";
 | 
					import Clipboard from "clipboard";
 | 
				
			||||||
import BlackDialog from "@/components/ui/BlackDialog.vue";
 | 
					import BlackDialog from "@/components/ui/BlackDialog.vue";
 | 
				
			||||||
import Compressor from "compressorjs";
 | 
					import Compressor from "compressorjs";
 | 
				
			||||||
import Generating from "@/components/ui/Generating.vue";
 | 
					import Generating from "@/components/ui/Generating.vue";
 | 
				
			||||||
 | 
					import {useSharedStore} from "@/store/sharedata";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const winHeight = ref(window.innerHeight - 50)
 | 
					const winHeight = ref(window.innerHeight - 50)
 | 
				
			||||||
const custom = ref(false)
 | 
					const custom = ref(false)
 | 
				
			||||||
@@ -333,6 +334,7 @@ const tags = ref([
 | 
				
			|||||||
  {label: "嘻哈", value: "hip hop"},
 | 
					  {label: "嘻哈", value: "hip hop"},
 | 
				
			||||||
])
 | 
					])
 | 
				
			||||||
const data = ref({
 | 
					const data = ref({
 | 
				
			||||||
 | 
					  client_id: getClientId(),
 | 
				
			||||||
  model: "chirp-v3-0",
 | 
					  model: "chirp-v3-0",
 | 
				
			||||||
  tags: "",
 | 
					  tags: "",
 | 
				
			||||||
  lyrics: "",
 | 
					  lyrics: "",
 | 
				
			||||||
@@ -354,45 +356,7 @@ const refSong = ref(null)
 | 
				
			|||||||
const showDialog = ref(false)
 | 
					const showDialog = ref(false)
 | 
				
			||||||
const editData = ref({title:"",cover:"",id:0})
 | 
					const editData = ref({title:"",cover:"",id:0})
 | 
				
			||||||
const promptPlaceholder = ref('请在这里输入你自己写的歌词...')
 | 
					const promptPlaceholder = ref('请在这里输入你自己写的歌词...')
 | 
				
			||||||
 | 
					const store = useSharedStore()
 | 
				
			||||||
const socket = ref(null)
 | 
					 | 
				
			||||||
const userId = ref(0)
 | 
					 | 
				
			||||||
const connect = () => {
 | 
					 | 
				
			||||||
  let host = process.env.VUE_APP_WS_HOST
 | 
					 | 
				
			||||||
  if (host === '') {
 | 
					 | 
				
			||||||
    if (location.protocol === 'https:') {
 | 
					 | 
				
			||||||
      host = 'wss://' + location.host;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      host = 'ws://' + location.host;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const _socket = new WebSocket(host + `/api/suno/client?user_id=${userId.value}`);
 | 
					 | 
				
			||||||
  _socket.addEventListener('open', () => {
 | 
					 | 
				
			||||||
    socket.value = _socket;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('message', event => {
 | 
					 | 
				
			||||||
    if (event.data instanceof Blob) {
 | 
					 | 
				
			||||||
      const reader = new FileReader();
 | 
					 | 
				
			||||||
      reader.readAsText(event.data, "UTF-8")
 | 
					 | 
				
			||||||
      reader.onload = () => {
 | 
					 | 
				
			||||||
        const message = String(reader.result)
 | 
					 | 
				
			||||||
        console.log(message)
 | 
					 | 
				
			||||||
        if (message === "FINISH" || message === "FAIL") {
 | 
					 | 
				
			||||||
          fetchData()
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('close', () => {
 | 
					 | 
				
			||||||
    if (socket.value !== null) {
 | 
					 | 
				
			||||||
      connect()
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const clipboard = ref(null)
 | 
					const clipboard = ref(null)
 | 
				
			||||||
onMounted(() => {
 | 
					onMounted(() => {
 | 
				
			||||||
  clipboard.value = new Clipboard('.copy-link');
 | 
					  clipboard.value = new Clipboard('.copy-link');
 | 
				
			||||||
@@ -405,10 +369,19 @@ onMounted(() => {
 | 
				
			|||||||
  })
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  checkSession().then(user => {
 | 
					  checkSession().then(user => {
 | 
				
			||||||
    userId.value = user.id
 | 
					    fetchData(1)
 | 
				
			||||||
    connect()
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  store.addMessageHandler("suno",(data) => {
 | 
				
			||||||
 | 
					    // 丢弃无关消息
 | 
				
			||||||
 | 
					    if (data.channel !== "suno" || data.clientId !== getClientId()) {
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (data.body === "FINISH" || data.body === "FAIL") {
 | 
				
			||||||
 | 
					      fetchData(1)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
  fetchData(1)
 | 
					 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onUnmounted(() => {
 | 
					onUnmounted(() => {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user