feat: 增加 可灵功能

This commit is contained in:
mario
2025-02-14 15:03:29 +08:00
parent dd675c9a9b
commit d124eddd9d
10 changed files with 825 additions and 399 deletions

View File

@@ -150,6 +150,7 @@ type SystemConfig struct {
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
KeLingPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力

View File

@@ -73,18 +73,18 @@ type SdTaskParams struct {
// DallTask DALL-E task
type DallTask struct {
ClientId string `json:"client_id"`
ModelId uint `json:"model_id"`
ModelName string `json:"model_name"`
Id uint `json:"id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
ClientId string `json:"client_id"`
ModelId uint `json:"model_id"`
ModelName string `json:"model_name"`
Id uint `json:"id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
}
type SunoTask struct {
@@ -109,6 +109,7 @@ const (
VideoLuma = "luma"
VideoRunway = "runway"
VideoCog = "cog"
VideoKeLing = "keling"
)
type VideoTask struct {
@@ -119,11 +120,11 @@ type VideoTask struct {
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
Params interface{} `json:"params"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
}
type VideoParams struct {
type LumaVideoParams struct {
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
Loop bool `json:"loop"` // 是否循环参考图
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
@@ -133,3 +134,33 @@ type VideoParams struct {
Style string `json:"style"` // 风格
Duration int `json:"duration"` // 视频时长(秒)
}
type KeLingVideoParams struct {
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
Model string `json:"model"` // 模型: default/anime
Prompt string `json:"prompt"` // 视频描述
NegPrompt string `json:"negative_prompt"` // 负面提示词
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
Mode string `json:"mode"` // 生成模式: std/pro
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
Duration string `json:"duration"` // 视频时长: 5/10
CameraControl CameraControl `json:"camera_control"` // 摄像机控制
Image string `json:"image"` // 参考图片URL(image2video)
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
}
// CameraControl 摄像机控制
type CameraControl struct {
Type string `json:"type"` // 控制类型: simple/down_back/forward_up/right_turn_forward/left_turn_forward
Config CameraConfig `json:"config"` // 控制参数(仅simple类型时使用)
}
// CameraConfig 摄像机参数
type CameraConfig struct {
Horizontal int `json:"horizontal"` // 水平移动(-10到10)
Vertical int `json:"vertical"` // 垂直移动(-10到10)
Pan int `json:"pan"` // 左右旋转(-10到10)
Tilt int `json:"tilt"` // 上下旋转(-10到10)
Roll int `json:"roll"` // 横向翻转(-10到10)
Zoom int `json:"zoom"` // 镜头缩放(-10到10)
}

View File

@@ -34,13 +34,14 @@ const (
MsgTypeErr = WsMsgType("error")
MsgTypePing = WsMsgType("ping") // 心跳消息
ChPing = WsChannel("ping")
ChChat = WsChannel("chat")
ChMj = WsChannel("mj")
ChSd = WsChannel("sd")
ChDall = WsChannel("dall")
ChSuno = WsChannel("suno")
ChLuma = WsChannel("luma")
ChPing = WsChannel("ping")
ChChat = WsChannel("chat")
ChMj = WsChannel("mj")
ChSd = WsChannel("sd")
ChDall = WsChannel("dall")
ChSuno = WsChannel("suno")
ChLuma = WsChannel("luma")
ChKeLing = WsChannel("keling")
)
// InputMessage 对话输入消息结构

View File

@@ -74,7 +74,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
params := types.LumaVideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
@@ -119,6 +119,98 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
resp.SUCCESS(c)
}
func (h *VideoHandler) KeLingCreate(c *gin.Context) {
var data struct {
Channel string `json:"channel"`
ClientId string `json:"client_id"`
TaskType string `json:"task_type"` // 任务类型: text2video/image2video
Model string `json:"model"` // 模型: default/anime
Prompt string `json:"prompt"` // 视频描述
NegPrompt string `json:"negative_prompt"` // 负面提示词
CfgScale float64 `json:"cfg_scale"` // 相关性系数(0-1)
Mode string `json:"mode"` // 生成模式: std/pro
AspectRatio string `json:"aspect_ratio"` // 画面比例: 16:9/9:16/1:1
Duration string `json:"duration"` // 视频时长: 5/10
CameraControl types.CameraControl `json:"camera_control"` // 摄像机控制
Image string `json:"image"` // 参考图片URL(image2video)
ImageTail string `json:"image_tail"` // 尾帧图片URL(image2video)
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.KeLingVideoParams{
TaskType: data.TaskType,
Model: data.Model,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
CfgScale: data.CfgScale,
Mode: data.Mode,
AspectRatio: data.AspectRatio,
Duration: data.Duration,
CameraControl: data.CameraControl,
Image: data.Image,
ImageTail: data.ImageTail,
}
task := types.VideoTask{
ClientId: data.ClientId,
UserId: userId,
Type: types.VideoKeLing,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.TranslateModelId,
Channel: data.Channel,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoKeLing,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
TaskInfo: utils.JsonEncode(task),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
task.Id = job.Id
h.videoService.PushTask(task)
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "keling",
Remark: fmt.Sprintf("keling 文生视频任务ID%d", job.Id),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")

View File

@@ -492,6 +492,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.POST("luma/create", h.LumaCreate)
group.POST("keling/create", h.KeLingCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)

View File

@@ -12,6 +12,7 @@ type NotifyMessage struct {
ClientId string `json:"client_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
Type string `json:"type"`
}
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@@ -1,377 +0,0 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService
clientIds map[uint]string
userService *service.UserService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
wsService: wsService,
uploadManager: manager,
clientIds: map[uint]string{},
userService: userService,
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.VideoTask
err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil {
logger.Errorf("decode task info with error: %v", err)
continue
}
task.Id = v.Id
s.PushTask(task)
s.clientIds[v.Id] = task.ClientId
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}()
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
QueueState interface{} `json:"queue_state"`
CreatedAt string `json:"created_at"`
Video interface{} `json:"video"`
VideoRaw interface{} `json:"video_raw"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Thumbnail interface{} `json:"thumbnail"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": task.Params.PromptOptimize,
"loop": task.Params.Loop,
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Infof("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
"cover_url": task.Thumbnail.Url,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
Thumbnail string `json:"thumbnail"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
UserId string `json:"user_id"`
BatchId string `json:"batch_id"`
Thumbnail struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"thumbnail"`
VideoRaw struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"video_raw"`
CreatedAt string `json:"created_at"`
LastFrame struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"last_frame"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

674
api/service/video/video.go Normal file
View File

@@ -0,0 +1,674 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"io/ioutil"
"net/http"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService
clientIds map[uint]string
userService *service.UserService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
wsService: wsService,
uploadManager: manager,
clientIds: map[uint]string{},
userService: userService,
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.VideoTask
err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil {
logger.Errorf("decode task info with error: %v", err)
continue
}
task.Id = v.Id
s.PushTask(task)
s.clientIds[v.Id] = task.ClientId
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
if task.Type == types.VideoLuma {
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoLuma})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
} else if task.Type == types.VideoKeLing {
var r KeLingRespVo
r, err = s.KeLingCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": r.Message,
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoKeLing})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Data.TaskID,
"channel": task.Channel,
"prompt_ext": task.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}
}()
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Infof("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
// Convert TaskInfo to VideoTask
var videoTask types.VideoTask
if err := json.Unmarshal([]byte(v.TaskInfo), &videoTask); err != nil {
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
continue
}
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished, Type: videoTask.Type})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
if job.Type == types.VideoLuma {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
"cover_url": task.Thumbnail.Url,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
} else if job.Type == types.VideoKeLing {
// Convert TaskInfo to VideoTask
var videoTask types.VideoTask
if err := json.Unmarshal([]byte(job.TaskInfo), &videoTask); err != nil {
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
continue
}
// Type assert task.Params to KeLingVideoParams
paramsMap, ok := videoTask.Params.(map[string]interface{})
if !ok {
continue
}
// Convert map to KeLingVideoParams
paramsBytes, err := json.Marshal(paramsMap)
if err != nil {
continue
}
var params types.KeLingVideoParams
if err := json.Unmarshal(paramsBytes, &params); err != nil {
continue
}
task, err := s.QueryKeLingTask(job.TaskId, job.Channel, params.TaskType)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.TaskStatus == "succeed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.TaskResult.Videos[0].URL,
"raw_data": utils.JsonEncode(task),
"prompt_ext": job.Prompt,
"cover_url": "",
}
if len(task.TaskResult.Videos) > 0 {
data["video_url"] = task.TaskResult.Videos[0].URL
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
}
// 找出失败的任务,并恢复其扣减算力
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.Type,
Remark: fmt.Sprintf("%s 任务失败退回算力。任务ID%sErr:%s", job.Type, job.TaskId, job.ErrMsg),
})
if err != nil {
continue
}
// 更新任务状态
s.db.Model(&job).UpdateColumn("power", 0)
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
Thumbnail string `json:"thumbnail"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
UserId string `json:"user_id"`
BatchId string `json:"batch_id"`
Thumbnail struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"thumbnail"`
VideoRaw struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"video_raw"`
CreatedAt string `json:"created_at"`
LastFrame struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"last_frame"`
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
QueueState interface{} `json:"queue_state"`
CreatedAt string `json:"created_at"`
Video interface{} `json:"video"`
VideoRaw interface{} `json:"video_raw"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Thumbnail interface{} `json:"thumbnail"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
// Type assert task.Params to LumaVideoParams
paramsMap, ok := task.Params.(map[string]interface{})
if !ok {
return LumaRespVo{}, errors.New("invalid params type for Luma video task")
}
// Convert map to LumaVideoParams
paramsBytes, err := json.Marshal(paramsMap)
if err != nil {
return LumaRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
}
var params types.LumaVideoParams
if err := json.Unmarshal(paramsBytes, &params); err != nil {
return LumaRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": params.PromptOptimize,
"loop": params.Loop,
"image_url": params.StartImgURL,
"image_end_url": params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}
type KeLingRespVo struct {
Code int `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
Data struct {
TaskID string `json:"task_id"`
TaskStatus string `json:"task_status"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
} `json:"data"`
}
func (s *Service) KeLingCreate(task types.VideoTask) (KeLingRespVo, error) {
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "keling").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return KeLingRespVo{}, errors.New("no available API KEY for keling")
}
// Type assert task.Params to KeLingVideoParams
paramsMap, ok := task.Params.(map[string]interface{})
if !ok {
return KeLingRespVo{}, errors.New("invalid params type for KeLing video task")
}
// Convert map to KeLingVideoParams
paramsBytes, err := json.Marshal(paramsMap)
if err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
}
var params types.KeLingVideoParams
if err := json.Unmarshal(paramsBytes, &params); err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
}
// 2. 构建API请求参数
payload := map[string]interface{}{
"model": params.Model,
"prompt": task.Prompt,
"negative_prompt": params.NegPrompt,
"cfg_scale": params.CfgScale,
"mode": params.Mode,
"aspect_ratio": params.AspectRatio,
"duration": params.Duration,
}
// 只有当 CameraControl 的类型不为空时,才处理摄像机控制参数
if params.CameraControl.Type != "" {
cameraControl := map[string]interface{}{
"type": params.CameraControl.Type,
}
// 只有在 simple 类型时才添加 config 参数
if params.CameraControl.Type == "simple" {
cameraControl["config"] = params.CameraControl.Config
}
payload["camera_control"] = cameraControl
}
jsonPayload, err := json.Marshal(payload)
if err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to marshal payload: %v", err)
}
// 3. 准备HTTP请求
url := fmt.Sprintf("%s/kling/v1/videos/%s", apiKey.ApiURL, params.TaskType)
req, err := http.NewRequest("POST", url, bytes.NewReader(jsonPayload))
if err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
req.Header.Set("Content-Type", "application/json")
// 4. 发送请求
client := &http.Client{Timeout: time.Duration(30) * time.Second}
resp, err := client.Do(req)
if err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()
// 5. 处理响应
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to read response: %v", err)
}
if resp.StatusCode != http.StatusOK {
return KeLingRespVo{}, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var apiResponse = KeLingRespVo{}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return KeLingRespVo{}, fmt.Errorf("failed to parse response: %v", err)
}
return apiResponse, nil
}
// VideoCallbackData 表示视频生成任务的回调数据
type VideoCallbackData struct {
TaskID string `json:"task_id"`
TaskStatus string `json:"task_status"`
TaskStatusMsg string `json:"task_status_msg"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
TaskResult TaskResult `json:"task_result"`
}
type TaskResult struct {
Images []CallBackImageResult `json:"images,omitempty"`
Videos []CallBackVideoResult `json:"videos,omitempty"`
}
type CallBackImageResult struct {
Index int `json:"index"`
URL string `json:"url"`
}
type CallBackVideoResult struct {
ID string `json:"id"`
URL string `json:"url"`
Duration string `json:"duration"`
}
func (s *Service) QueryKeLingTask(taskId string, channel string, action string) (VideoCallbackData, error) {
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "keling").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return VideoCallbackData{}, errors.New("no available API KEY for keling")
}
url := fmt.Sprintf("%s/kling/v1/videos/%s/%s", apiKey.ApiURL, action, taskId)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return VideoCallbackData{}, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
res, err := client.Do(req)
if err != nil {
return VideoCallbackData{}, fmt.Errorf("failed to execute request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err)
}
var response struct {
Code int `json:"code"`
Message string `json:"message"`
Data VideoCallbackData `json:"data"`
}
if err := json.Unmarshal(body, &response); err != nil {
return VideoCallbackData{}, fmt.Errorf("failed to unmarshal response: %w", err)
}
if response.Code != 0 {
return VideoCallbackData{}, fmt.Errorf("API error: %s", response.Message)
}
return response.Data, nil
}

1
config/config.yaml Normal file
View File

@@ -0,0 +1 @@

View File

@@ -140,6 +140,7 @@ const types = ref([
{ label: "DALL-E", value: "dalle" },
{ label: "Suno文生歌", value: "suno" },
{ label: "Luma视频", value: "luma" },
{ label: "可灵视频", value: "keling" },
{ label: "Realtime API", value: "realtime" },
{ label: "其他", value: "other" },
]);