websocket api refactor is ready

This commit is contained in:
RockYang
2024-09-29 19:28:47 +08:00
parent 8093a3eeb2
commit 1a1734abf0
19 changed files with 210 additions and 464 deletions

View File

@@ -72,18 +72,20 @@ type SdTaskParams struct {
// DallTask DALL-E task
type DallTask struct {
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
ClientId string `json:"client_id"`
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
}
type SunoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
@@ -107,13 +109,14 @@ const (
)
type VideoTask struct {
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
ClientId string `json:"client_id"`
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
}
type VideoParams struct {

View File

@@ -84,19 +84,15 @@ func (h *DallJobHandler) Image(c *gin.Context) {
}
h.dallService.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
ClientId: data.ClientId,
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.dallService.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}

View File

@@ -8,23 +8,15 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// MarkMapHandler 生成思维导图
@@ -44,23 +36,33 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
// Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
ModelId int `json:"model_id"`
}
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
userId := h.GetLoginUserId(c)
var user model.User
res := h.DB.Model(&model.User{}).First(&user, userId)
if res.Error != nil {
return fmt.Errorf("error with query user info: %v", res.Error)
err := h.DB.Where("id", userId).First(&user, userId).Error
if err != nil {
resp.ERROR(c, "error with query user info")
return
}
var chatModel model.ChatModel
res = h.DB.Where("id", modelId).First(&chatModel)
if res.Error != nil {
return fmt.Errorf("error with query chat model: %v", res.Error)
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if err != nil {
resp.ERROR(c, "error with query chat model")
return
}
if user.Power < chatModel.Power {
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power)
resp.ERROR(c, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型算力%d", user.Power, chatModel.Power))
return
}
messages := make([]interface{}, 0)
@@ -82,117 +84,27 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
### 支付宝
### 微信
另外,除此之外不要任何解释性语句。
请直接生成结果,不要任何解释性语句。
`})
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", prompt)})
var req = types.ApiRequest{
Model: chatModel.Value,
Stream: true,
Messages: messages,
}
var apiKey model.ApiKey
response, err := h.doRequest(req, chatModel, &apiKey)
messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图要求结构清晰有条理", data.Prompt)})
content, err := utils.SendOpenAIMessage(h.DB, messages, chatModel.Value, chatModel.KeyId)
if err != nil {
return fmt.Errorf("请求 OpenAI API 失败: %s", err)
}
defer response.Body.Close()
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return fmt.Errorf("error with decode data: %v", line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
utils.SendMsg(client, types.ReplyMessage{
Type: types.MsgTypeText,
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API 失败: %s", err))
return
}
// 扣减算力
if chatModel.Power > 0 {
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
err = h.userService.DecreasePower(int(userId), chatModel.Power, model.PowerLog{
Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
})
if err != nil {
return err
resp.ERROR(c, "error with save power log, "+err.Error())
return
}
}
return nil
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
if chatModel.KeyId > 0 {
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
if err != nil {
return nil, err
}
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
if len(apiKey.ProxyURL) > 5 { // 使用代理
proxy, _ := url.Parse(apiKey.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
} else {
client = http.DefaultClient
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
return client.Do(request)
resp.SUCCESS(c, content)
}

View File

@@ -232,15 +232,6 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
if err != nil {
continue
}
if item.Progress < 100 {
// 从 leveldb 中获取图片预览数据
var imageData string
err = h.leveldb.Get(item.TaskId, &imageData)
if err == nil {
job.ImgURL = "data:image/png;base64," + imageData
}
}
jobs = append(jobs, job)
}

View File

@@ -45,6 +45,7 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
func (h *SunoHandler) Create(c *gin.Context) {
var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"`
@@ -115,6 +116,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
// 创建任务
h.sunoService.PushTask(types.SunoTask{
ClientId: data.ClientId,
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
@@ -141,10 +143,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
return
}
client := h.sunoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
@@ -365,7 +363,7 @@ func (h *SunoHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini", 0)
if err != nil {
resp.ERROR(c, err.Error())
return

View File

@@ -45,6 +45,7 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
@@ -95,11 +96,12 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
// 创建任务
h.videoService.PushTask(types.VideoTask{
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
ClientId: data.ClientId,
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
})
// update user's power
@@ -112,11 +114,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
resp.ERROR(c, err.Error())
return
}
client := h.videoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
@@ -175,7 +172,7 @@ func (h *VideoHandler) Remove(c *gin.Context) {
return
}
// 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
if !(job.Progress == service.FailTaskProgress || time.Now().After(job.CreatedAt.Add(time.Minute*30))) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return
}

View File

@@ -77,7 +77,7 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
continue
}
logger.Infof("Receive a message:%+v", message)
logger.Debugf("Receive a message:%+v", message)
if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong")
continue

View File

@@ -34,19 +34,21 @@ type Service struct {
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
userService *service.UserService
wsService *service.WebsocketService
clientIds map[uint]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploadManager: manager,
userService: userService,
clientIds: map[uint]string{},
}
}
@@ -67,6 +69,7 @@ func (s *Service) Run() {
continue
}
logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.JobId] = task.ClientId
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
@@ -74,7 +77,7 @@ func (s *Service) Run() {
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
}
}
}()
@@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
@@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", fmt.Errorf("err with update database: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
utils.SendChannelMsg(client, types.ChDall, message.Message)
}
}()
}
@@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil
}

View File

@@ -58,7 +58,7 @@ func (s *Service) Run() {
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil {
task.Prompt = content
} else {
@@ -67,7 +67,7 @@ func (s *Service) Run() {
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini", 0)
if err == nil {
task.NegPrompt = content
} else {
@@ -169,6 +169,7 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
logger.Debugf("receive a new mj notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue

View File

@@ -33,7 +33,6 @@ type Service struct {
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
leveldb *store.LevelDB
wsService *service.WebsocketService
}
@@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db,
leveldb: levelDB,
wsService: wsService,
uploadManager: manager,
}
@@ -62,7 +60,7 @@ func (s *Service) Run() {
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
if err == nil {
task.Params.Prompt = content
} else {
@@ -72,7 +70,7 @@ func (s *Service) Run() {
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
if err == nil {
task.Params.NegPrompt = content
} else {
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
// TaskProgressResp 任务进度响应实体
type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
}
// Txt2Img 文生图 API
@@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
err, resp := s.checkTaskProgress(apiKey)
@@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
}
time.Sleep(time.Second)
}
@@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue

View File

@@ -34,17 +34,19 @@ type Service struct {
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
wsService *service.WebsocketService
clientIds map[string]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
wsService: wsService,
clientIds: map[string]string{},
}
}
@@ -96,7 +98,7 @@ func (s *Service) Run() {
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
@@ -105,6 +107,7 @@ func (s *Service) Run() {
"task_id": r.Data,
"channel": r.Channel,
})
s.clientIds[r.Data] = task.ClientId
}
}()
}
@@ -271,14 +274,14 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
logger.Debugf("client id: %+v", s.wsService.Clients)
client := s.wsService.Clients.Get(message.ClientId)
logger.Debugf("%+v", client)
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
utils.SendChannelMsg(client, types.ChSuno, message.Message)
}
}()
}
@@ -311,7 +314,7 @@ func (s *Service) DownloadFiles() {
v.AudioURL = audioURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
@@ -377,12 +380,12 @@ func (s *Service) SyncTaskProgress() {
}
}
tx.Commit()
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
} else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
}
}

View File

@@ -34,17 +34,19 @@ type Service struct {
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
wsService *service.WebsocketService
clientIds map[uint]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploadManager: manager,
clientIds: map[uint]string{},
}
}
@@ -85,7 +87,7 @@ func (s *Service) Run() {
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini", 0)
if err == nil {
task.Prompt = content
} else {
@@ -93,6 +95,10 @@ func (s *Service) Run() {
}
}
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
@@ -105,7 +111,7 @@ func (s *Service) Run() {
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
@@ -190,14 +196,12 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
@@ -237,7 +241,7 @@ func (s *Service) DownloadFiles() {
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)

View File

@@ -45,18 +45,25 @@ type apiRes struct {
} `json:"choices"`
}
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
var apiKey model.ApiKey
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
if res.Error != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
}
func OpenAIRequest(db *gorm.DB, prompt string, modelName string, keyId int) (string, error) {
messages := make([]interface{}, 1)
messages[0] = types.Message{
Role: "user",
Content: prompt,
}
return SendOpenAIMessage(db, messages, modelName, keyId)
}
func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelName string, keyId int) (string, error) {
var apiKey model.ApiKey
session := db.Session(&gorm.Session{}).Where("type", "chat").Where("enabled", true)
if keyId > 0 {
session = session.Where("id", keyId)
}
err := session.First(&apiKey).Error
if err != nil {
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", err)
}
var response apiRes
client := req.C()