From d124eddd9d82a93832aa1bd0739a6bab1a4ec72a Mon Sep 17 00:00:00 2001 From: mario Date: Fri, 14 Feb 2025 15:03:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20=E5=8F=AF=E7=81=B5?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/types/config.go | 1 + api/core/types/task.go | 59 ++- api/core/types/web.go | 15 +- api/handler/video_handler.go | 94 ++++- api/main.go | 1 + api/service/types.go | 1 + api/service/video/luma.go | 377 ------------------ api/service/video/video.go | 674 +++++++++++++++++++++++++++++++++ config/config.yaml | 1 + web/src/views/admin/ApiKey.vue | 1 + 10 files changed, 825 insertions(+), 399 deletions(-) delete mode 100644 api/service/video/luma.go create mode 100644 api/service/video/video.go create mode 100644 config/config.yaml diff --git a/api/core/types/config.go b/api/core/types/config.go index 5ce0bdfc..a5679bdc 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -150,6 +150,7 @@ type SystemConfig struct { DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力 SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力 LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力 + KeLingPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力 AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力 PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力 diff --git a/api/core/types/task.go b/api/core/types/task.go index 63c7ec29..769f8fc3 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -73,18 +73,18 @@ type SdTaskParams struct { // DallTask DALL-E task type DallTask struct { - ClientId string `json:"client_id"` - ModelId uint `json:"model_id"` - ModelName string `json:"model_name"` - Id uint `json:"id"` - UserId uint `json:"user_id"` - Prompt string `json:"prompt"` - N int `json:"n"` - Quality string `json:"quality"` - Size string `json:"size"` - Style string `json:"style"` - Power int `json:"power"` - TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID + ClientId string `json:"client_id"` + ModelId uint `json:"model_id"` + ModelName string `json:"model_name"` + Id uint `json:"id"` + UserId uint `json:"user_id"` + Prompt string `json:"prompt"` + N int `json:"n"` + Quality string `json:"quality"` + Size string `json:"size"` + Style string `json:"style"` + Power int `json:"power"` + TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID } type SunoTask struct { @@ -109,6 +109,7 @@ const ( VideoLuma = "luma" VideoRunway = "runway" VideoCog = "cog" + VideoKeLing = "keling" ) type VideoTask struct { @@ -119,11 +120,11 @@ type VideoTask struct { Type string `json:"type"` TaskId string `json:"task_id"` Prompt string `json:"prompt"` // 提示词 - Params VideoParams `json:"params"` + Params interface{} `json:"params"` TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID } -type VideoParams struct { +type LumaVideoParams struct { PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词 Loop bool `json:"loop"` // 是否循环参考图 StartImgURL string `json:"start_img_url"` // 第一帧参考图地址 @@ -133,3 +134,33 @@ type VideoParams struct { Style string `json:"style"` // 风格 Duration int `json:"duration"` // 视频时长(秒) } + +type KeLingVideoParams struct { + TaskType string `json:"task_type"` // 任务类型: text2video/image2video + Model string `json:"model"` // 模型: default/anime + Prompt string `json:"prompt"` // 视频描述 + NegPrompt string `json:"negative_prompt"` // 负面提示词 + CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1) + Mode string `json:"mode"` // 生成模式: std/pro + AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1 + Duration string `json:"duration"` // 视频时长: 5/10 + CameraControl CameraControl `json:"camera_control"` // 摄像机控制 + Image string `json:"image"` // 参考图片URL(image2video) + ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video) +} + +// CameraControl 摄像机控制 +type CameraControl struct { + Type string `json:"type"` // 控制类型: simple/down_back/forward_up/right_turn_forward/left_turn_forward + Config CameraConfig `json:"config"` // 控制参数(仅simple类型时使用) +} + +// CameraConfig 摄像机参数 +type CameraConfig struct { + Horizontal int `json:"horizontal"` // 水平移动(-10到10) + Vertical int `json:"vertical"` // 垂直移动(-10到10) + Pan int `json:"pan"` // 左右旋转(-10到10) + Tilt int `json:"tilt"` // 上下旋转(-10到10) + Roll int `json:"roll"` // 横向翻转(-10到10) + Zoom int `json:"zoom"` // 镜头缩放(-10到10) +} diff --git a/api/core/types/web.go b/api/core/types/web.go index e2b5e636..0546f18e 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -34,13 +34,14 @@ const ( MsgTypeErr = WsMsgType("error") MsgTypePing = WsMsgType("ping") // 心跳消息 - ChPing = WsChannel("ping") - ChChat = WsChannel("chat") - ChMj = WsChannel("mj") - ChSd = WsChannel("sd") - ChDall = WsChannel("dall") - ChSuno = WsChannel("suno") - ChLuma = WsChannel("luma") + ChPing = WsChannel("ping") + ChChat = WsChannel("chat") + ChMj = WsChannel("mj") + ChSd = WsChannel("sd") + ChDall = WsChannel("dall") + ChSuno = WsChannel("suno") + ChLuma = WsChannel("luma") + ChKeLing = WsChannel("keling") ) // InputMessage 对话输入消息结构 diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index e42cd9ca..1fa8cf86 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -74,7 +74,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { } userId := int(h.GetLoginUserId(c)) - params := types.VideoParams{ + params := types.LumaVideoParams{ PromptOptimize: data.ExpandPrompt, Loop: data.Loop, StartImgURL: data.FirstFrameImg, @@ -119,6 +119,98 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { resp.SUCCESS(c) } +func (h *VideoHandler) KeLingCreate(c *gin.Context) { + + var data struct { + Channel string `json:"channel"` + ClientId string `json:"client_id"` + TaskType string `json:"task_type"` // 任务类型: text2video/image2video + Model string `json:"model"` // 模型: default/anime + Prompt string `json:"prompt"` // 视频描述 + NegPrompt string `json:"negative_prompt"` // 负面提示词 + CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1) + Mode string `json:"mode"` // 生成模式: std/pro + AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1 + Duration string `json:"duration"` // 视频时长: 5/10 + CameraControl types.CameraControl `json:"camera_control"` // 摄像机控制 + Image string `json:"image"` // 参考图片URL(image2video) + ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video) + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + if user.Power < h.App.SysConfig.LumaPower { + resp.ERROR(c, "您的算力不足,请充值后再试!") + return + } + + if data.Prompt == "" { + resp.ERROR(c, "prompt is needed") + return + } + + userId := int(h.GetLoginUserId(c)) + params := types.KeLingVideoParams{ + TaskType: data.TaskType, + Model: data.Model, + Prompt: data.Prompt, + NegPrompt: data.NegPrompt, + CfgScale: data.CfgScale, + Mode: data.Mode, + AspectRatio: data.AspectRatio, + Duration: data.Duration, + CameraControl: data.CameraControl, + Image: data.Image, + ImageTail: data.ImageTail, + } + task := types.VideoTask{ + ClientId: data.ClientId, + UserId: userId, + Type: types.VideoKeLing, + Prompt: data.Prompt, + Params: params, + TranslateModelId: h.App.SysConfig.TranslateModelId, + Channel: data.Channel, + } + // 插入数据库 + job := model.VideoJob{ + UserId: userId, + Type: types.VideoKeLing, + Prompt: data.Prompt, + Power: h.App.SysConfig.LumaPower, + TaskInfo: utils.JsonEncode(task), + } + tx := h.DB.Create(&job) + if tx.Error != nil { + resp.ERROR(c, tx.Error.Error()) + return + } + + // 创建任务 + task.Id = job.Id + h.videoService.PushTask(task) + + // update user's power + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "keling", + Remark: fmt.Sprintf("keling 文生视频,任务ID:%d", job.Id), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + resp.SUCCESS(c) +} + func (h *VideoHandler) List(c *gin.Context) { userId := h.GetLoginUserId(c) t := c.Query("type") diff --git a/api/main.go b/api/main.go index 234fba2a..acfc2cf9 100644 --- a/api/main.go +++ b/api/main.go @@ -492,6 +492,7 @@ func main() { fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) { group := s.Engine.Group("/api/video") group.POST("luma/create", h.LumaCreate) + group.POST("keling/create", h.KeLingCreate) group.GET("list", h.List) group.GET("remove", h.Remove) group.GET("publish", h.Publish) diff --git a/api/service/types.go b/api/service/types.go index a2b69e36..e8b48203 100644 --- a/api/service/types.go +++ b/api/service/types.go @@ -12,6 +12,7 @@ type NotifyMessage struct { ClientId string `json:"client_id"` JobId int `json:"job_id"` Message string `json:"message"` + Type string `json:"type"` } const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/api/service/video/luma.go b/api/service/video/luma.go deleted file mode 100644 index e40b9b47..00000000 --- a/api/service/video/luma.go +++ /dev/null @@ -1,377 +0,0 @@ -package video - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "encoding/json" - "errors" - "fmt" - "geekai/core/types" - logger2 "geekai/logger" - "geekai/service" - "geekai/service/oss" - "geekai/store" - "geekai/store/model" - "geekai/utils" - "github.com/go-redis/redis/v8" - "io" - "time" - - "github.com/imroc/req/v3" - "gorm.io/gorm" -) - -var logger = logger2.GetLogger() - -type Service struct { - httpClient *req.Client - db *gorm.DB - uploadManager *oss.UploaderManager - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - wsService *service.WebsocketService - clientIds map[uint]string - userService *service.UserService -} - -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service { - return &Service{ - httpClient: req.C().SetTimeout(time.Minute * 3), - db: db, - taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), - notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), - wsService: wsService, - uploadManager: manager, - clientIds: map[uint]string{}, - userService: userService, - } -} - -func (s *Service) PushTask(task types.VideoTask) { - logger.Infof("add a new Video task to the task list: %+v", task) - s.taskQueue.RPush(task) -} - -func (s *Service) Run() { - // 将数据库中未提交的人物加载到队列 - var jobs []model.VideoJob - s.db.Where("task_id", "").Where("progress", 0).Find(&jobs) - for _, v := range jobs { - var task types.VideoTask - err := utils.JsonDecode(v.TaskInfo, &task) - if err != nil { - logger.Errorf("decode task info with error: %v", err) - continue - } - task.Id = v.Id - s.PushTask(task) - s.clientIds[v.Id] = task.ClientId - } - logger.Info("Starting Video job consumer...") - go func() { - for { - var task types.VideoTask - err := s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } - - // translate prompt - if utils.HasChinese(task.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId) - if err == nil { - task.Prompt = content - } else { - logger.Warnf("error with translate prompt: %v", err) - } - } - - if task.ClientId != "" { - s.clientIds[task.Id] = task.ClientId - } - - var r LumaRespVo - r, err = s.LumaCreate(task) - if err != nil { - logger.Errorf("create task with error: %v", err) - err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ - "err_msg": err.Error(), - "progress": service.FailTaskProgress, - "cover_url": "/images/failed.jpg", - }).Error - if err != nil { - logger.Errorf("update task with error: %v", err) - } - s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) - continue - } - - // 更新任务信息 - err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ - "task_id": r.Id, - "channel": r.Channel, - "prompt_ext": r.Prompt, - }).Error - if err != nil { - logger.Errorf("update task with error: %v", err) - s.PushTask(task) - } - } - }() -} - -type LumaRespVo struct { - Id string `json:"id"` - Prompt string `json:"prompt"` - State string `json:"state"` - QueueState interface{} `json:"queue_state"` - CreatedAt string `json:"created_at"` - Video interface{} `json:"video"` - VideoRaw interface{} `json:"video_raw"` - Liked interface{} `json:"liked"` - EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` - Thumbnail interface{} `json:"thumbnail"` - Channel string `json:"channel,omitempty"` -} - -func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) { - // 读取 API KEY - var apiKey model.ApiKey - session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true) - if task.Channel != "" { - session = session.Where("api_url", task.Channel) - } - tx := session.Order("last_used_at DESC").First(&apiKey) - if tx.Error != nil { - return LumaRespVo{}, errors.New("no available API KEY for Luma") - } - - reqBody := map[string]interface{}{ - "user_prompt": task.Prompt, - "expand_prompt": task.Params.PromptOptimize, - "loop": task.Params.Loop, - "image_url": task.Params.StartImgURL, - "image_end_url": task.Params.EndImgURL, - } - var res LumaRespVo - apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL) - logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) - r, err := req.C().R(). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(reqBody). - Post(apiURL) - if err != nil { - return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err) - } - - if r.StatusCode != 200 && r.StatusCode != 201 { - return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) - } - - body, _ := io.ReadAll(r.Body) - err = json.Unmarshal(body, &res) - if err != nil { - return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) - } - - // update the last_use_at for api key - apiKey.LastUsedAt = time.Now().Unix() - session.Updates(&apiKey) - res.Channel = apiKey.ApiURL - return res, nil -} - -func (s *Service) CheckTaskNotify() { - go func() { - logger.Info("Running Suno task notify checking ...") - for { - var message service.NotifyMessage - err := s.notifyQueue.LPop(&message) - if err != nil { - continue - } - logger.Debugf("Receive notify message: %+v", message) - client := s.wsService.Clients.Get(message.ClientId) - if client == nil { - continue - } - utils.SendChannelMsg(client, types.ChLuma, message.Message) - } - }() -} - -func (s *Service) DownloadFiles() { - go func() { - var items []model.VideoJob - for { - res := s.db.Where("progress", 102).Find(&items) - if res.Error != nil { - continue - } - - for _, v := range items { - if v.WaterURL == "" { - continue - } - - logger.Infof("try download video: %s", v.WaterURL) - videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true) - if err != nil { - logger.Errorf("download video with error: %v", err) - continue - } - logger.Infof("download video success: %s", videoURL) - v.WaterURL = videoURL - - if v.VideoURL != "" { - logger.Infof("try download no water video: %s", v.VideoURL) - videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true) - if err != nil { - logger.Errorf("download video with error: %v", err) - continue - } - } - logger.Infof("download no water video success: %s", videoURL) - v.VideoURL = videoURL - v.Progress = 100 - s.db.Updates(&v) - s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) - } - - time.Sleep(time.Second * 10) - } - }() -} - -// SyncTaskProgress 异步拉取任务 -func (s *Service) SyncTaskProgress() { - go func() { - var jobs []model.VideoJob - for { - res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) - if res.Error != nil { - continue - } - - for _, job := range jobs { - task, err := s.QueryLumaTask(job.TaskId, job.Channel) - if err != nil { - logger.Errorf("query task with error: %v", err) - // 更新任务信息 - s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ - "progress": service.FailTaskProgress, // 102 表示资源未下载完成, - "err_msg": err.Error(), - }) - continue - } - - logger.Debugf("task: %+v", task) - if task.State == "completed" { // 更新任务信息 - data := map[string]interface{}{ - "progress": 102, // 102 表示资源未下载完成, - "water_url": task.Video.Url, - "raw_data": utils.JsonEncode(task), - "prompt_ext": task.Prompt, - "cover_url": task.Thumbnail.Url, - } - if task.Video.DownloadUrl != "" { - data["video_url"] = task.Video.DownloadUrl - } - err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error - if err != nil { - logger.Errorf("更新数据库失败:%v", err) - continue - } - } - - } - - // 找出失败的任务,并恢复其扣减算力 - s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) - for _, job := range jobs { - err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "luma", - Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - }) - if err != nil { - continue - } - // 更新任务状态 - s.db.Model(&job).UpdateColumn("power", 0) - } - time.Sleep(time.Second * 10) - } - }() -} - -type LumaTaskVo struct { - Id string `json:"id"` - Liked interface{} `json:"liked"` - State string `json:"state"` - Video struct { - Url string `json:"url"` - Width int `json:"width"` - Height int `json:"height"` - Thumbnail string `json:"thumbnail"` - DownloadUrl string `json:"download_url"` - } `json:"video"` - Prompt string `json:"prompt"` - UserId string `json:"user_id"` - BatchId string `json:"batch_id"` - Thumbnail struct { - Url string `json:"url"` - Width int `json:"width"` - Height int `json:"height"` - } `json:"thumbnail"` - VideoRaw struct { - Url string `json:"url"` - Width int `json:"width"` - Height int `json:"height"` - } `json:"video_raw"` - CreatedAt string `json:"created_at"` - LastFrame struct { - Url string `json:"url"` - Width int `json:"width"` - Height int `json:"height"` - } `json:"last_frame"` -} - -func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) { - // 读取 API KEY - var apiKey model.ApiKey - err := s.db.Session(&gorm.Session{}).Where("type", "luma"). - Where("api_url", channel). - Where("enabled", true). - Order("last_used_at DESC").First(&apiKey).Error - if err != nil { - return LumaTaskVo{}, errors.New("no available API KEY for Luma") - } - - apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId) - var res LumaTaskVo - r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) - - if err != nil { - return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err) - } - defer r.Body.Close() - - if r.StatusCode != 200 { - return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String()) - } - - body, _ := io.ReadAll(r.Body) - err = json.Unmarshal(body, &res) - if err != nil { - return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) - } - - return res, nil -} diff --git a/api/service/video/video.go b/api/service/video/video.go new file mode 100644 index 00000000..e37bc085 --- /dev/null +++ b/api/service/video/video.go @@ -0,0 +1,674 @@ +package video + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "github.com/go-redis/redis/v8" + "io" + "io/ioutil" + "net/http" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + wsService *service.WebsocketService + clientIds map[uint]string + userService *service.UserService +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), + wsService: wsService, + uploadManager: manager, + clientIds: map[uint]string{}, + userService: userService, + } +} + +func (s *Service) PushTask(task types.VideoTask) { + logger.Infof("add a new Video task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + // 将数据库中未提交的人物加载到队列 + var jobs []model.VideoJob + s.db.Where("task_id", "").Where("progress", 0).Find(&jobs) + for _, v := range jobs { + var task types.VideoTask + err := utils.JsonDecode(v.TaskInfo, &task) + if err != nil { + logger.Errorf("decode task info with error: %v", err) + continue + } + task.Id = v.Id + s.PushTask(task) + s.clientIds[v.Id] = task.ClientId + } + logger.Info("Starting Video job consumer...") + go func() { + for { + var task types.VideoTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + if task.ClientId != "" { + s.clientIds[task.Id] = task.ClientId + } + + if task.Type == types.VideoLuma { + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId) + if err == nil { + task.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + var r LumaRespVo + r, err = s.LumaCreate(task) + if err != nil { + logger.Errorf("create task with error: %v", err) + err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "err_msg": err.Error(), + "progress": service.FailTaskProgress, + "cover_url": "/images/failed.jpg", + }).Error + if err != nil { + logger.Errorf("update task with error: %v", err) + } + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoLuma}) + continue + } + + // 更新任务信息 + err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "task_id": r.Id, + "channel": r.Channel, + "prompt_ext": r.Prompt, + }).Error + if err != nil { + logger.Errorf("update task with error: %v", err) + s.PushTask(task) + } + } else if task.Type == types.VideoKeLing { + var r KeLingRespVo + r, err = s.KeLingCreate(task) + if err != nil { + logger.Errorf("create task with error: %v", err) + err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "err_msg": r.Message, + "progress": service.FailTaskProgress, + "cover_url": "/images/failed.jpg", + }).Error + if err != nil { + logger.Errorf("update task with error: %v", err) + } + s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoKeLing}) + continue + } + + // 更新任务信息 + err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "task_id": r.Data.TaskID, + "channel": task.Channel, + "prompt_ext": task.Prompt, + }).Error + if err != nil { + logger.Errorf("update task with error: %v", err) + s.PushTask(task) + } + } + + } + }() +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running Suno task notify checking ...") + for { + var message service.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + logger.Debugf("Receive notify message: %+v", message) + client := s.wsService.Clients.Get(message.ClientId) + if client == nil { + continue + } + utils.SendChannelMsg(client, types.ChLuma, message.Message) + } + }() +} + +func (s *Service) DownloadFiles() { + go func() { + var items []model.VideoJob + for { + res := s.db.Where("progress", 102).Find(&items) + if res.Error != nil { + continue + } + + for _, v := range items { + if v.WaterURL == "" { + continue + } + + logger.Infof("try download video: %s", v.WaterURL) + videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true) + if err != nil { + logger.Errorf("download video with error: %v", err) + continue + } + logger.Infof("download video success: %s", videoURL) + v.WaterURL = videoURL + + if v.VideoURL != "" { + logger.Infof("try download no water video: %s", v.VideoURL) + videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true) + if err != nil { + logger.Errorf("download video with error: %v", err) + continue + } + } + logger.Infof("download no water video success: %s", videoURL) + v.VideoURL = videoURL + v.Progress = 100 + s.db.Updates(&v) + + // Convert TaskInfo to VideoTask + var videoTask types.VideoTask + if err := json.Unmarshal([]byte(v.TaskInfo), &videoTask); err != nil { + logger.Errorf("failed to unmarshal task info to VideoTask: %v", err) + continue + } + + s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished, Type: videoTask.Type}) + } + + time.Sleep(time.Second * 10) + } + }() +} + +// SyncTaskProgress 异步拉取任务 +func (s *Service) SyncTaskProgress() { + go func() { + var jobs []model.VideoJob + for { + res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) + if res.Error != nil { + continue + } + + for _, job := range jobs { + if job.Type == types.VideoLuma { + task, err := s.QueryLumaTask(job.TaskId, job.Channel) + if err != nil { + logger.Errorf("query task with error: %v", err) + // 更新任务信息 + s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": service.FailTaskProgress, // 102 表示资源未下载完成, + "err_msg": err.Error(), + }) + continue + } + + logger.Debugf("task: %+v", task) + if task.State == "completed" { // 更新任务信息 + data := map[string]interface{}{ + "progress": 102, // 102 表示资源未下载完成, + "water_url": task.Video.Url, + "raw_data": utils.JsonEncode(task), + "prompt_ext": task.Prompt, + "cover_url": task.Thumbnail.Url, + } + if task.Video.DownloadUrl != "" { + data["video_url"] = task.Video.DownloadUrl + } + err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error + if err != nil { + logger.Errorf("更新数据库失败:%v", err) + continue + } + } + } else if job.Type == types.VideoKeLing { + // Convert TaskInfo to VideoTask + var videoTask types.VideoTask + if err := json.Unmarshal([]byte(job.TaskInfo), &videoTask); err != nil { + logger.Errorf("failed to unmarshal task info to VideoTask: %v", err) + continue + } + + // Type assert task.Params to KeLingVideoParams + paramsMap, ok := videoTask.Params.(map[string]interface{}) + if !ok { + continue + } + + // Convert map to KeLingVideoParams + paramsBytes, err := json.Marshal(paramsMap) + if err != nil { + continue + } + + var params types.KeLingVideoParams + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + continue + } + + task, err := s.QueryKeLingTask(job.TaskId, job.Channel, params.TaskType) + if err != nil { + logger.Errorf("query task with error: %v", err) + // 更新任务信息 + s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": service.FailTaskProgress, // 102 表示资源未下载完成, + "err_msg": err.Error(), + }) + continue + } + + logger.Debugf("task: %+v", task) + if task.TaskStatus == "succeed" { // 更新任务信息 + data := map[string]interface{}{ + "progress": 102, // 102 表示资源未下载完成, + "water_url": task.TaskResult.Videos[0].URL, + "raw_data": utils.JsonEncode(task), + "prompt_ext": job.Prompt, + "cover_url": "", + } + if len(task.TaskResult.Videos) > 0 { + data["video_url"] = task.TaskResult.Videos[0].URL + } + err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error + if err != nil { + logger.Errorf("更新数据库失败:%v", err) + continue + } + } + } + + } + + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: job.Type, + Remark: fmt.Sprintf("%s 任务失败,退回算力。任务ID:%s,Err:%s", job.Type, job.TaskId, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } + time.Sleep(time.Second * 10) + } + }() +} + +type LumaTaskVo struct { + Id string `json:"id"` + Liked interface{} `json:"liked"` + State string `json:"state"` + Video struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + Thumbnail string `json:"thumbnail"` + DownloadUrl string `json:"download_url"` + } `json:"video"` + Prompt string `json:"prompt"` + UserId string `json:"user_id"` + BatchId string `json:"batch_id"` + Thumbnail struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + } `json:"thumbnail"` + VideoRaw struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + } `json:"video_raw"` + CreatedAt string `json:"created_at"` + LastFrame struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + } `json:"last_frame"` +} + +type LumaRespVo struct { + Id string `json:"id"` + Prompt string `json:"prompt"` + State string `json:"state"` + QueueState interface{} `json:"queue_state"` + CreatedAt string `json:"created_at"` + Video interface{} `json:"video"` + VideoRaw interface{} `json:"video_raw"` + Liked interface{} `json:"liked"` + EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` + Thumbnail interface{} `json:"thumbnail"` + Channel string `json:"channel,omitempty"` +} + +func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true) + if task.Channel != "" { + session = session.Where("api_url", task.Channel) + } + tx := session.Order("last_used_at DESC").First(&apiKey) + if tx.Error != nil { + return LumaRespVo{}, errors.New("no available API KEY for Luma") + } + + // Type assert task.Params to LumaVideoParams + paramsMap, ok := task.Params.(map[string]interface{}) + if !ok { + return LumaRespVo{}, errors.New("invalid params type for Luma video task") + } + + // Convert map to LumaVideoParams + paramsBytes, err := json.Marshal(paramsMap) + if err != nil { + return LumaRespVo{}, fmt.Errorf("failed to marshal params: %v", err) + } + + var params types.LumaVideoParams + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return LumaRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err) + } + + reqBody := map[string]interface{}{ + "user_prompt": task.Prompt, + "expand_prompt": params.PromptOptimize, + "loop": params.Loop, + "image_url": params.StartImgURL, + "image_end_url": params.EndImgURL, + } + var res LumaRespVo + apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL) + logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). + Post(apiURL) + if err != nil { + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.StatusCode != 200 && r.StatusCode != 201 { + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) + } + + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + // update the last_use_at for api key + apiKey.LastUsedAt = time.Now().Unix() + session.Updates(&apiKey) + res.Channel = apiKey.ApiURL + return res, nil +} + +func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + err := s.db.Session(&gorm.Session{}).Where("type", "luma"). + Where("api_url", channel). + Where("enabled", true). + Order("last_used_at DESC").First(&apiKey).Error + if err != nil { + return LumaTaskVo{}, errors.New("no available API KEY for Luma") + } + + apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId) + var res LumaTaskVo + r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) + + if err != nil { + return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err) + } + defer r.Body.Close() + + if r.StatusCode != 200 { + return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String()) + } + + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + return res, nil +} + +type KeLingRespVo struct { + Code int `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` + Data struct { + TaskID string `json:"task_id"` + TaskStatus string `json:"task_status"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + } `json:"data"` +} + +func (s *Service) KeLingCreate(task types.VideoTask) (KeLingRespVo, error) { + var apiKey model.ApiKey + session := s.db.Session(&gorm.Session{}).Where("type", "keling").Where("enabled", true) + if task.Channel != "" { + session = session.Where("api_url", task.Channel) + } + tx := session.Order("last_used_at DESC").First(&apiKey) + if tx.Error != nil { + return KeLingRespVo{}, errors.New("no available API KEY for keling") + } + + // Type assert task.Params to KeLingVideoParams + paramsMap, ok := task.Params.(map[string]interface{}) + if !ok { + return KeLingRespVo{}, errors.New("invalid params type for KeLing video task") + } + + // Convert map to KeLingVideoParams + paramsBytes, err := json.Marshal(paramsMap) + if err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to marshal params: %v", err) + } + + var params types.KeLingVideoParams + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err) + } + + // 2. 构建API请求参数 + payload := map[string]interface{}{ + "model": params.Model, + "prompt": task.Prompt, + "negative_prompt": params.NegPrompt, + "cfg_scale": params.CfgScale, + "mode": params.Mode, + "aspect_ratio": params.AspectRatio, + "duration": params.Duration, + } + + // 只有当 CameraControl 的类型不为空时,才处理摄像机控制参数 + if params.CameraControl.Type != "" { + cameraControl := map[string]interface{}{ + "type": params.CameraControl.Type, + } + + // 只有在 simple 类型时才添加 config 参数 + if params.CameraControl.Type == "simple" { + cameraControl["config"] = params.CameraControl.Config + } + + payload["camera_control"] = cameraControl + } + + jsonPayload, err := json.Marshal(payload) + if err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to marshal payload: %v", err) + } + + // 3. 准备HTTP请求 + url := fmt.Sprintf("%s/kling/v1/videos/%s", apiKey.ApiURL, params.TaskType) + req, err := http.NewRequest("POST", url, bytes.NewReader(jsonPayload)) + if err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to create request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+apiKey.Value) + req.Header.Set("Content-Type", "application/json") + + // 4. 发送请求 + client := &http.Client{Timeout: time.Duration(30) * time.Second} + resp, err := client.Do(req) + if err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to send request: %v", err) + } + defer resp.Body.Close() + + // 5. 处理响应 + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to read response: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return KeLingRespVo{}, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var apiResponse = KeLingRespVo{} + if err := json.Unmarshal(body, &apiResponse); err != nil { + return KeLingRespVo{}, fmt.Errorf("failed to parse response: %v", err) + } + + return apiResponse, nil +} + +// VideoCallbackData 表示视频生成任务的回调数据 +type VideoCallbackData struct { + TaskID string `json:"task_id"` + TaskStatus string `json:"task_status"` + TaskStatusMsg string `json:"task_status_msg"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + TaskResult TaskResult `json:"task_result"` +} + +type TaskResult struct { + Images []CallBackImageResult `json:"images,omitempty"` + Videos []CallBackVideoResult `json:"videos,omitempty"` +} + +type CallBackImageResult struct { + Index int `json:"index"` + URL string `json:"url"` +} + +type CallBackVideoResult struct { + ID string `json:"id"` + URL string `json:"url"` + Duration string `json:"duration"` +} + +func (s *Service) QueryKeLingTask(taskId string, channel string, action string) (VideoCallbackData, error) { + var apiKey model.ApiKey + err := s.db.Session(&gorm.Session{}).Where("type", "keling"). + Where("api_url", channel). + Where("enabled", true). + Order("last_used_at DESC").First(&apiKey).Error + if err != nil { + return VideoCallbackData{}, errors.New("no available API KEY for keling") + } + + url := fmt.Sprintf("%s/kling/v1/videos/%s/%s", apiKey.ApiURL, action, taskId) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return VideoCallbackData{}, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+apiKey.Value) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + res, err := client.Do(req) + if err != nil { + return VideoCallbackData{}, fmt.Errorf("failed to execute request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err) + } + + var response struct { + Code int `json:"code"` + Message string `json:"message"` + Data VideoCallbackData `json:"data"` + } + + if err := json.Unmarshal(body, &response); err != nil { + return VideoCallbackData{}, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if response.Code != 0 { + return VideoCallbackData{}, fmt.Errorf("API error: %s", response.Message) + } + + return response.Data, nil +} diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/config/config.yaml @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index 5bb4dafb..756ba091 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -140,6 +140,7 @@ const types = ref([ { label: "DALL-E", value: "dalle" }, { label: "Suno文生歌", value: "suno" }, { label: "Luma视频", value: "luma" }, + { label: "可灵视频", value: "keling" }, { label: "Realtime API", value: "realtime" }, { label: "其他", value: "other" }, ]);