mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-25 17:45:58 +08:00
websocket api refactor is ready
This commit is contained in:
@@ -72,18 +72,20 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
ClientId string `json:"client_id"`
|
||||
JobId uint `json:"job_id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
|
||||
Power int `json:"power"`
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
@@ -107,13 +109,14 @@ const (
|
||||
)
|
||||
|
||||
type VideoTask struct {
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
ClientId string `json:"client_id"`
|
||||
Id uint `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
UserId int `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
Params VideoParams `json:"params"`
|
||||
}
|
||||
|
||||
type VideoParams struct {
|
||||
|
||||
@@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.dallService.PushTask(types.DallTask{
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
ClientId: data.ClientId,
|
||||
JobId: job.Id,
|
||||
UserId: uint(userId),
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
Power: job.Power,
|
||||
})
|
||||
|
||||
client := h.dallService.Clients.Get(job.UserId)
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,23 +8,15 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MarkMapHandler 生成思维导图
|
||||
@@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
ModelId int `json:"model_id"`
|
||||
}
|
||||
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, userId)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query user info: %v", res.Error)
|
||||
err := h.DB.Where("id", userId).First(&user, userId).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query user info")
|
||||
return
|
||||
}
|
||||
var chatModel model.ChatModel
|
||||
res = h.DB.Where("id", modelId).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with query chat model: %v", res.Error)
|
||||
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with query chat model")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < chatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
|
||||
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power))
|
||||
return
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 0)
|
||||
@@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
||||
### 支付宝
|
||||
### 微信
|
||||
|
||||
另外,除此之外不要任何解释性语句。
|
||||
请直接生成结果,不要任何解释性语句。
|
||||
`})
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)})
|
||||
var req = types.ApiRequest{
|
||||
Model: chatModel.Value,
|
||||
Stream: true,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
response, err := h.doRequest(req, chatModel, &apiKey)
|
||||
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", data.Prompt)})
|
||||
content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
contentType := response.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
// 循环读取 Chunk 消息
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
continue
|
||||
}
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
return fmt.Errorf("error with decode data: %v", line)
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "stop" {
|
||||
break
|
||||
}
|
||||
|
||||
utils.SendMsg(client, types.ReplyMessage{
|
||||
Type: types.MsgTypeText,
|
||||
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
} // end for
|
||||
|
||||
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
|
||||
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
|
||||
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减算力
|
||||
if chatModel.Power > 0 {
|
||||
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
|
||||
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: chatModel.Value,
|
||||
Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
resp.ERROR(c, "error with save power log, "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if chatModel.KeyId > 0 {
|
||||
session = session.Where("id", chatModel.KeyId)
|
||||
} else { // use the last unused key
|
||||
session = session.Where("type", "chat").
|
||||
Where("enabled", true).Order("last_used_at ASC")
|
||||
}
|
||||
|
||||
res := session.First(apiKey)
|
||||
if res.Error != nil {
|
||||
return nil, errors.New("no available key, please import key")
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
// 更新 API KEY 的最后使用时间
|
||||
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 创建 HttpClient 请求对象
|
||||
var client *http.Client
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if len(apiKey.ProxyURL) > 5 { // 使用代理
|
||||
proxy, _ := url.Parse(apiKey.ProxyURL)
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
|
||||
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
|
||||
return client.Do(request)
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
@@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if item.Progress < 100 {
|
||||
// 从 leveldb 中获取图片预览数据
|
||||
var imageData string
|
||||
err = h.leveldb.Get(item.TaskId, &imageData)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + imageData
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Instrumental bool `json:"instrumental"`
|
||||
Lyrics string `json:"lyrics"`
|
||||
@@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.sunoService.PushTask(types.SunoTask{
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
@@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client := h.sunoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
|
||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
var data struct {
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImg string `json:"first_frame_img,omitempty"`
|
||||
EndFrameImg string `json:"end_frame_img,omitempty"`
|
||||
@@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
|
||||
// 创建任务
|
||||
h.videoService.PushTask(types.VideoTask{
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
ClientId: data.ClientId,
|
||||
Id: job.Id,
|
||||
UserId: userId,
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
})
|
||||
|
||||
// update user's power
|
||||
@@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
client := h.videoService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 只有失败或者超时的任务才能删除
|
||||
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
|
||||
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
|
||||
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("Receive a message:%+v", message)
|
||||
logger.Debugf("Receive a message:%+v", message)
|
||||
if message.Type == types.MsgTypePing {
|
||||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||||
continue
|
||||
|
||||
@@ -34,19 +34,21 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +69,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
s.clientIds[task.JobId] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
@@ -74,7 +77,7 @@ func (s *Service) Run() {
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
@@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
|
||||
@@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@@ -67,7 +67,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
@@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("receive a new mj notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
|
||||
@@ -33,7 +33,6 @@ type Service struct {
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
leveldb *store.LevelDB
|
||||
wsService *service.WebsocketService
|
||||
}
|
||||
|
||||
@@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
}
|
||||
@@ -62,7 +60,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
@@ -72,7 +70,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
|
||||
|
||||
// TaskProgressResp 任务进度响应实体
|
||||
type TaskProgressResp struct {
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
CurrentImage string `json:"current_image"`
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
@@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
@@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
|
||||
@@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[string]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
wsService: wsService,
|
||||
clientIds: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +98,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -105,6 +107,7 @@ func (s *Service) Run() {
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
s.clientIds[r.Data] = task.ClientId
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
logger.Debugf("client id: %+v", s.wsService.Clients)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
logger.Debugf("%+v", client)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSuno, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
@@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,17 +34,19 @@ type Service struct {
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +87,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@@ -93,6 +95,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
if task.ClientId != "" {
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
}
|
||||
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
@@ -105,7 +111,7 @@ func (s *Service) Run() {
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
logger.Debugf("Receive notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChLuma, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() {
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
|
||||
@@ -45,18 +45,25 @@ type apiRes struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
}
|
||||
return SendOpenAIMessage(db, messages, modelName, keyId)
|
||||
}
|
||||
|
||||
func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
|
||||
if keyId > 0 {
|
||||
session = session.Where("id", keyId)
|
||||
}
|
||||
err := session.First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err)
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
client := req.C()
|
||||
|
||||
Reference in New Issue
Block a user