mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-27 11:09:23 +08:00
664 lines
19 KiB
Go
664 lines
19 KiB
Go
package video
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"geekai/core/types"
|
||
logger2 "geekai/logger"
|
||
"geekai/service"
|
||
"geekai/service/oss"
|
||
"geekai/store"
|
||
"geekai/store/model"
|
||
"geekai/utils"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/go-redis/redis/v8"
|
||
|
||
"github.com/imroc/req/v3"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// logger is the package-level logger shared by every function in this file.
var logger = logger2.GetLogger()
|
||
|
||
type Service struct {
|
||
httpClient *req.Client
|
||
db *gorm.DB
|
||
uploadManager *oss.UploaderManager
|
||
taskQueue *store.RedisQueue
|
||
wsService *service.WebsocketService
|
||
userService *service.UserService
|
||
}
|
||
|
||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service {
|
||
return &Service{
|
||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||
db: db,
|
||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||
wsService: wsService,
|
||
uploadManager: manager,
|
||
userService: userService,
|
||
}
|
||
}
|
||
|
||
func (s *Service) PushTask(task types.VideoTask) {
|
||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||
s.taskQueue.RPush(task)
|
||
}
|
||
|
||
func (s *Service) Run() {
|
||
// 将数据库中未提交的任务加载到队列
|
||
var jobs []model.VideoJob
|
||
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||
for _, v := range jobs {
|
||
var task types.VideoTask
|
||
err := utils.JsonDecode(v.TaskInfo, &task)
|
||
if err != nil {
|
||
logger.Errorf("decode task info with error: %v", err)
|
||
continue
|
||
}
|
||
task.Id = v.Id
|
||
s.PushTask(task)
|
||
}
|
||
logger.Info("Starting Video job consumer...")
|
||
go func() {
|
||
for {
|
||
var task types.VideoTask
|
||
err := s.taskQueue.LPop(&task)
|
||
if err != nil {
|
||
logger.Errorf("taking task with error: %v", err)
|
||
continue
|
||
}
|
||
|
||
if task.Type == types.VideoLuma {
|
||
// translate prompt
|
||
if utils.HasChinese(task.Prompt) {
|
||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||
if err == nil {
|
||
task.Prompt = content
|
||
} else {
|
||
logger.Warnf("error with translate prompt: %v", err)
|
||
}
|
||
}
|
||
var r LumaRespVo
|
||
r, err = s.LumaCreate(task)
|
||
if err != nil {
|
||
logger.Errorf("create task with error: %v", err)
|
||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||
"err_msg": err.Error(),
|
||
"progress": service.FailTaskProgress,
|
||
"cover_url": "/images/failed.jpg",
|
||
}).Error
|
||
if err != nil {
|
||
logger.Errorf("update task with error: %v", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 更新任务信息
|
||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||
"task_id": r.Id,
|
||
"channel": r.Channel,
|
||
"prompt_ext": r.Prompt,
|
||
}).Error
|
||
if err != nil {
|
||
logger.Errorf("update task with error: %v", err)
|
||
s.PushTask(task)
|
||
}
|
||
} else if task.Type == types.VideoKeLing {
|
||
var r KeLingRespVo
|
||
r, err = s.KeLingCreate(task)
|
||
logger.Debugf("ke ling create task result: %+v", r)
|
||
|
||
if err != nil {
|
||
logger.Errorf("create task with error: %v", err)
|
||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||
"err_msg": err.Error(),
|
||
"progress": service.FailTaskProgress,
|
||
"cover_url": "/images/failed.jpg",
|
||
}).Error
|
||
if err != nil {
|
||
logger.Errorf("update task with error: %v", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 更新任务信息
|
||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||
"task_id": r.Data.TaskID,
|
||
"channel": r.Channel,
|
||
"prompt_ext": task.Prompt,
|
||
}).Error
|
||
if err != nil {
|
||
logger.Errorf("update task with error: %v", err)
|
||
s.PushTask(task)
|
||
}
|
||
}
|
||
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (s *Service) DownloadFiles() {
|
||
go func() {
|
||
var items []model.VideoJob
|
||
for {
|
||
res := s.db.Where("progress", 102).Find(&items)
|
||
if res.Error != nil {
|
||
continue
|
||
}
|
||
|
||
for _, v := range items {
|
||
if v.WaterURL == "" {
|
||
continue
|
||
}
|
||
|
||
logger.Infof("try download video: %s", v.WaterURL)
|
||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||
if err != nil {
|
||
logger.Errorf("download video with error: %v", err)
|
||
continue
|
||
}
|
||
logger.Infof("download video success: %s", videoURL)
|
||
v.WaterURL = videoURL
|
||
|
||
if v.VideoURL != "" {
|
||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||
if err != nil {
|
||
logger.Errorf("download video with error: %v", err)
|
||
continue
|
||
}
|
||
}
|
||
logger.Infof("download no water video success: %s", videoURL)
|
||
v.VideoURL = videoURL
|
||
v.Progress = 100
|
||
s.db.Updates(&v)
|
||
|
||
// Convert TaskInfo to VideoTask
|
||
var videoTask types.VideoTask
|
||
if err := json.Unmarshal([]byte(v.TaskInfo), &videoTask); err != nil {
|
||
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
|
||
continue
|
||
}
|
||
|
||
}
|
||
|
||
time.Sleep(time.Second * 10)
|
||
}
|
||
}()
|
||
}
|
||
|
||
// SyncTaskProgress 异步拉取任务
|
||
func (s *Service) SyncTaskProgress() {
|
||
go func() {
|
||
var jobs []model.VideoJob
|
||
for {
|
||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||
if res.Error != nil {
|
||
continue
|
||
}
|
||
|
||
for _, job := range jobs {
|
||
if job.Type == types.VideoLuma {
|
||
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
|
||
if err != nil {
|
||
logger.Errorf("query task with error: %v", err)
|
||
// 更新任务信息
|
||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||
"err_msg": err.Error(),
|
||
"cover_url": "/images/failed.jpg",
|
||
})
|
||
continue
|
||
}
|
||
|
||
logger.Debugf("task: %+v", task)
|
||
if task.State == "completed" { // 更新任务信息
|
||
data := map[string]interface{}{
|
||
"progress": 102, // 102 表示资源未下载完成,
|
||
"water_url": task.Video.Url,
|
||
"raw_data": utils.JsonEncode(task),
|
||
"prompt_ext": task.Prompt,
|
||
"cover_url": task.Thumbnail.Url,
|
||
}
|
||
if task.Video.DownloadUrl != "" {
|
||
data["video_url"] = task.Video.DownloadUrl
|
||
}
|
||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||
if err != nil {
|
||
logger.Errorf("更新数据库失败:%v", err)
|
||
continue
|
||
}
|
||
}
|
||
} else if job.Type == types.VideoKeLing {
|
||
// Convert TaskInfo to VideoTask
|
||
var videoTask types.VideoTask
|
||
if err := json.Unmarshal([]byte(job.TaskInfo), &videoTask); err != nil {
|
||
logger.Errorf("failed to unmarshal task info to VideoTask: %v", err)
|
||
continue
|
||
}
|
||
|
||
// Type assert task.Params to KeLingVideoParams
|
||
paramsMap, ok := videoTask.Params.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// Convert map to KeLingVideoParams
|
||
paramsBytes, err := json.Marshal(paramsMap)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
var params types.KeLingVideoParams
|
||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||
continue
|
||
}
|
||
|
||
task, err := s.QueryKeLingTask(job.TaskId, job.Channel, params.TaskType)
|
||
if err != nil {
|
||
logger.Errorf("query task with error: %v", err)
|
||
// 更新任务信息
|
||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||
"err_msg": err.Error(),
|
||
"cover_url": "/images/failed.jpg",
|
||
})
|
||
continue
|
||
}
|
||
|
||
logger.Debugf("task: %+v", task)
|
||
if task.TaskStatus == "succeed" { // 更新任务信息
|
||
data := map[string]interface{}{
|
||
"progress": 102, // 102 表示资源未下载完成,
|
||
"water_url": task.TaskResult.Videos[0].URL,
|
||
"raw_data": utils.JsonEncode(task),
|
||
"prompt_ext": job.Prompt,
|
||
"cover_url": "",
|
||
}
|
||
if len(task.TaskResult.Videos) > 0 {
|
||
data["video_url"] = task.TaskResult.Videos[0].URL
|
||
}
|
||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||
if err != nil {
|
||
logger.Errorf("更新数据库失败:%v", err)
|
||
continue
|
||
}
|
||
} else if task.TaskStatus == "failed" {
|
||
// 更新任务信息
|
||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||
"progress": service.FailTaskProgress,
|
||
"err_msg": task.TaskStatusMsg,
|
||
"cover_url": "/images/failed.jpg",
|
||
})
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
// 找出失败的任务,并恢复其扣减算力
|
||
s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs)
|
||
for _, job := range jobs {
|
||
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||
Type: types.PowerRefund,
|
||
Model: job.Type,
|
||
Remark: fmt.Sprintf("%s 任务失败,退回算力。任务ID:%s,Err:%s", job.Type, job.TaskId, job.ErrMsg),
|
||
})
|
||
if err != nil {
|
||
continue
|
||
}
|
||
// 更新任务状态
|
||
s.db.Model(&job).UpdateColumn("power", 0)
|
||
}
|
||
time.Sleep(time.Second * 10)
|
||
}
|
||
}()
|
||
}
|
||
|
||
// LumaTaskVo is the JSON document QueryLumaTask decodes when it polls a Luma
// generation.
type LumaTaskVo struct {
	Id    string      `json:"id"`
	Liked interface{} `json:"liked"`
	// State is the generation state; SyncTaskProgress treats "completed" as
	// finished.
	State string `json:"state"`
	// Video describes the generated video. SyncTaskProgress stores Url as the
	// job's water_url and, when non-empty, DownloadUrl as its video_url.
	Video struct {
		Url         string `json:"url"`
		Width       int    `json:"width"`
		Height      int    `json:"height"`
		Thumbnail   string `json:"thumbnail"`
		DownloadUrl string `json:"download_url"`
	} `json:"video"`
	Prompt  string `json:"prompt"`
	UserId  string `json:"user_id"`
	BatchId string `json:"batch_id"`
	// Thumbnail describes the preview image; its Url is stored as the job's
	// cover_url.
	Thumbnail struct {
		Url    string `json:"url"`
		Width  int    `json:"width"`
		Height int    `json:"height"`
	} `json:"thumbnail"`
	VideoRaw struct {
		Url    string `json:"url"`
		Width  int    `json:"width"`
		Height int    `json:"height"`
	} `json:"video_raw"`
	CreatedAt string `json:"created_at"`
	LastFrame struct {
		Url    string `json:"url"`
		Width  int    `json:"width"`
		Height int    `json:"height"`
	} `json:"last_frame"`
}
|
||
|
||
// LumaRespVo is the JSON document LumaCreate decodes from the response to a
// Luma generation request.
type LumaRespVo struct {
	// Id is stored as the job's task_id by the queue consumer.
	Id string `json:"id"`
	// Prompt is stored as the job's prompt_ext by the queue consumer.
	Prompt              string      `json:"prompt"`
	State               string      `json:"state"`
	QueueState          interface{} `json:"queue_state"`
	CreatedAt           string      `json:"created_at"`
	Video               interface{} `json:"video"`
	VideoRaw            interface{} `json:"video_raw"`
	Liked               interface{} `json:"liked"`
	EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
	Thumbnail           interface{} `json:"thumbnail"`
	// Channel is filled in by LumaCreate with the API URL of the key that was
	// used for the request.
	Channel string `json:"channel,omitempty"`
}
|
||
|
||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
|
||
// 读取 API KEY
|
||
var apiKey model.ApiKey
|
||
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
|
||
if task.Channel != "" {
|
||
session = session.Where("api_url", task.Channel)
|
||
}
|
||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||
if tx.Error != nil {
|
||
return LumaRespVo{}, errors.New("no available API KEY for Luma")
|
||
}
|
||
|
||
// Type assert task.Params to LumaVideoParams
|
||
paramsMap, ok := task.Params.(map[string]interface{})
|
||
if !ok {
|
||
return LumaRespVo{}, errors.New("invalid params type for Luma video task")
|
||
}
|
||
|
||
// Convert map to LumaVideoParams
|
||
paramsBytes, err := json.Marshal(paramsMap)
|
||
if err != nil {
|
||
return LumaRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
|
||
}
|
||
|
||
var params types.LumaVideoParams
|
||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||
return LumaRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
|
||
}
|
||
|
||
reqBody := map[string]interface{}{
|
||
"user_prompt": task.Prompt,
|
||
"expand_prompt": params.PromptOptimize,
|
||
"loop": params.Loop,
|
||
"image_url": params.StartImgURL, // 图生视频
|
||
"image_end_url": params.EndImgURL, // 图生视频
|
||
}
|
||
|
||
var res LumaRespVo
|
||
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
|
||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||
r, err := req.C().R().
|
||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||
SetBody(reqBody).
|
||
Post(apiURL)
|
||
if err != nil {
|
||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||
}
|
||
|
||
if r.StatusCode != 200 && r.StatusCode != 201 {
|
||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||
}
|
||
|
||
body, _ := io.ReadAll(r.Body)
|
||
err = json.Unmarshal(body, &res)
|
||
if err != nil {
|
||
return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||
}
|
||
|
||
// update the last_use_at for api key
|
||
apiKey.LastUsedAt = time.Now().Unix()
|
||
session.Updates(&apiKey)
|
||
res.Channel = apiKey.ApiURL
|
||
return res, nil
|
||
}
|
||
|
||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
|
||
// 读取 API KEY
|
||
var apiKey model.ApiKey
|
||
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
|
||
Where("api_url", channel).
|
||
Where("enabled", true).
|
||
Order("last_used_at DESC").First(&apiKey).Error
|
||
if err != nil {
|
||
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
|
||
}
|
||
|
||
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
|
||
var res LumaTaskVo
|
||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||
|
||
if err != nil {
|
||
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||
}
|
||
defer r.Body.Close()
|
||
|
||
if r.StatusCode != 200 {
|
||
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
|
||
}
|
||
|
||
body, _ := io.ReadAll(r.Body)
|
||
err = json.Unmarshal(body, &res)
|
||
if err != nil {
|
||
return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||
}
|
||
|
||
return res, nil
|
||
}
|
||
|
||
// KeLingRespVo is the JSON document KeLingCreate decodes from the response to
// a KeLing task creation request.
type KeLingRespVo struct {
	Code      int    `json:"code"`
	Message   string `json:"message"`
	RequestID string `json:"request_id"`
	// Data describes the created task; TaskID is stored as the job's task_id
	// by the queue consumer.
	Data struct {
		TaskID     string `json:"task_id"`
		TaskStatus string `json:"task_status"`
		CreatedAt  int64  `json:"created_at"`
		UpdatedAt  int64  `json:"updated_at"`
	} `json:"data"`
	// Channel is filled in by KeLingCreate with the API URL of the key that
	// was used for the request.
	Channel string `json:"channel,omitempty"`
}
|
||
|
||
func (s *Service) KeLingCreate(task types.VideoTask) (KeLingRespVo, error) {
|
||
var apiKey model.ApiKey
|
||
session := s.db.Session(&gorm.Session{}).Where("type", "keling").Where("enabled", true)
|
||
if task.Channel != "" {
|
||
session = session.Where("api_url", task.Channel)
|
||
}
|
||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||
if tx.Error != nil {
|
||
return KeLingRespVo{}, errors.New("no available API KEY for keling")
|
||
}
|
||
|
||
// Type assert task.Params to KeLingVideoParams
|
||
paramsMap, ok := task.Params.(map[string]interface{})
|
||
if !ok {
|
||
return KeLingRespVo{}, errors.New("invalid params type for KeLing video task")
|
||
}
|
||
|
||
// Convert map to KeLingVideoParams
|
||
paramsBytes, err := json.Marshal(paramsMap)
|
||
if err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to marshal params: %v", err)
|
||
}
|
||
|
||
var params types.KeLingVideoParams
|
||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to unmarshal params: %v", err)
|
||
}
|
||
|
||
// 2. 构建API请求参数
|
||
payload := map[string]interface{}{
|
||
"model_name": params.Model,
|
||
"prompt": task.Prompt,
|
||
"negative_prompt": params.NegPrompt,
|
||
"cfg_scale": params.CfgScale,
|
||
"mode": params.Mode,
|
||
"aspect_ratio": params.AspectRatio,
|
||
"duration": params.Duration,
|
||
}
|
||
|
||
// 只有当 CameraControl 的类型不为空时,才处理摄像机控制参数
|
||
if params.CameraControl.Type != "" {
|
||
cameraControl := map[string]interface{}{
|
||
"type": params.CameraControl.Type,
|
||
}
|
||
|
||
// 只有在 simple 类型时才添加 config 参数
|
||
if params.CameraControl.Type == "simple" {
|
||
cameraControl["config"] = params.CameraControl.Config
|
||
}
|
||
|
||
payload["camera_control"] = cameraControl
|
||
}
|
||
|
||
// 处理图生视频
|
||
if params.TaskType == "image2video" {
|
||
payload["image"] = params.Image
|
||
payload["image_tail"] = params.ImageTail
|
||
}
|
||
|
||
jsonPayload, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to marshal payload: %v", err)
|
||
}
|
||
|
||
// 3. 准备HTTP请求
|
||
url := fmt.Sprintf("%s/kling/v1/videos/%s", apiKey.ApiURL, params.TaskType)
|
||
req, err := http.NewRequest("POST", url, bytes.NewReader(jsonPayload))
|
||
if err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to create request: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
// 4. 发送请求
|
||
client := &http.Client{Timeout: time.Duration(30) * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to send request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 5. 处理响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to read response: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return KeLingRespVo{}, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var apiResponse = KeLingRespVo{}
|
||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||
return KeLingRespVo{}, fmt.Errorf("failed to parse response: %v", err)
|
||
}
|
||
// 设置 API 通道
|
||
apiResponse.Channel = apiKey.ApiURL
|
||
return apiResponse, nil
|
||
}
|
||
|
||
// VideoCallbackData 表示视频生成任务的回调数据
|
||
type VideoCallbackData struct {
|
||
TaskID string `json:"task_id"`
|
||
TaskStatus string `json:"task_status"`
|
||
TaskStatusMsg string `json:"task_status_msg"`
|
||
CreatedAt int64 `json:"created_at"`
|
||
UpdatedAt int64 `json:"updated_at"`
|
||
TaskResult TaskResult `json:"task_result"`
|
||
}
|
||
|
||
type TaskResult struct {
|
||
Images []CallBackImageResult `json:"images,omitempty"`
|
||
Videos []CallBackVideoResult `json:"videos,omitempty"`
|
||
}
|
||
|
||
// CallBackImageResult is one image entry of a TaskResult.
type CallBackImageResult struct {
	Index int    `json:"index"`
	URL   string `json:"url"`
}
|
||
|
||
// CallBackVideoResult is one video entry of a TaskResult.
type CallBackVideoResult struct {
	ID  string `json:"id"`
	URL string `json:"url"`
	// Duration is carried as a string, exactly as it appears in the JSON.
	Duration string `json:"duration"`
}
|
||
|
||
func (s *Service) QueryKeLingTask(taskId string, channel string, action string) (VideoCallbackData, error) {
|
||
var apiKey model.ApiKey
|
||
err := s.db.Session(&gorm.Session{}).Where("type", "keling").
|
||
//Where("api_url", channel).
|
||
Where("enabled", true).
|
||
Order("last_used_at DESC").First(&apiKey).Error
|
||
if err != nil {
|
||
return VideoCallbackData{}, errors.New("no available API KEY for keling")
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/kling/v1/videos/%s/%s", apiKey.ApiURL, action, taskId)
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return VideoCallbackData{}, fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+apiKey.Value)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{}
|
||
res, err := client.Do(req)
|
||
if err != nil {
|
||
return VideoCallbackData{}, fmt.Errorf("failed to execute request: %w", err)
|
||
}
|
||
defer res.Body.Close()
|
||
|
||
if res.StatusCode != http.StatusOK {
|
||
return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode)
|
||
}
|
||
|
||
body, err := io.ReadAll(res.Body)
|
||
if err != nil {
|
||
return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err)
|
||
}
|
||
|
||
var response struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
Data VideoCallbackData `json:"data"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &response); err != nil {
|
||
return VideoCallbackData{}, fmt.Errorf("failed to unmarshal response: %w", err)
|
||
}
|
||
|
||
if response.Code != 0 {
|
||
return VideoCallbackData{}, fmt.Errorf("API error: %s", response.Message)
|
||
}
|
||
|
||
return response.Data, nil
|
||
}
|