From a16ef6476d3b4c010e2b466f980856580ce9761c Mon Sep 17 00:00:00 2001 From: RockYang Date: Fri, 30 Aug 2024 18:12:14 +0800 Subject: [PATCH] add luma api service --- api/core/types/task.go | 28 ++- api/main.go | 2 +- api/service/suno/service.go | 2 +- api/service/video/luma.go | 351 +++++++++++++++++++++++++++++++++++ api/store/model/video_job.go | 27 +++ api/store/vo/suno_job.go | 4 - api/store/vo/video_job.go | 20 ++ database/update-v4.1.3.sql | 27 ++- 8 files changed, 453 insertions(+), 8 deletions(-) create mode 100644 api/service/video/luma.go create mode 100644 api/store/model/video_job.go create mode 100644 api/store/vo/video_job.go diff --git a/api/core/types/task.go b/api/core/types/task.go index 2affae7d..d41e592a 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -85,7 +85,6 @@ type SunoTask struct { Channel string `json:"channel"` UserId int `json:"user_id"` Type int `json:"type"` - TaskId string `json:"task_id"` Title string `json:"title"` RefTaskId string `json:"ref_task_id,omitempty"` RefSongId string `json:"ref_song_id,omitempty"` @@ -97,3 +96,30 @@ type SunoTask struct { SongId string `json:"song_id,omitempty"` // 合并歌曲ID AudioURL string `json:"audio_url"` // 用户上传音频地址 } + +const ( + VideoLuma = "luma" + VideoRunway = "runway" + VideoCog = "cog" +) + +type VideoTask struct { + Id uint `json:"id"` + Channel string `json:"channel"` + UserId int `json:"user_id"` + Type string `json:"type"` + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + Params VideoParams `json:"params"` +} + +type VideoParams struct { + PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词 + Loop bool `json:"loop"` // 是否循环参考图 + StartImgURL string `json:"start_img_url"` // 第一帧参考图地址 + EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址 + Model string `json:"model"` // 使用哪个模型生成视频 + Radio string `json:"radio"` // 视频尺寸 + Style string `json:"style"` // 风格 + Duration int `json:"duration"` // 视频时长(秒) +} diff --git a/api/main.go b/api/main.go index 70611a6b..31dfc0f6 100644 --- a/api/main.go +++ b/api/main.go @@ -199,7 +199,7 @@ func main() { s.Run() s.SyncTaskProgress() s.CheckTaskNotify() - s.DownloadImages() + s.DownloadFiles() }), fx.Provide(payment.NewAlipayService), diff --git a/api/service/suno/service.go b/api/service/suno/service.go index a49bcb47..53b213ef 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -279,7 +279,7 @@ func (s *Service) CheckTaskNotify() { }() } -func (s *Service) DownloadImages() { +func (s *Service) DownloadFiles() { go func() { var items []model.SunoJob for { diff --git a/api/service/video/luma.go b/api/service/video/luma.go new file mode 100644 index 00000000..c94d2565 --- /dev/null +++ b/api/service/video/luma.go @@ -0,0 +1,351 @@ +package video + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "github.com/go-redis/redis/v8" + "io" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +func (s *Service) PushTask(task types.VideoTask) { + logger.Infof("add a new Video task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + // 将数据库中未提交的人物加载到队列 + var jobs []model.VideoJob + s.db.Where("task_id", "").Find(&jobs) + for _, v := range jobs { + var params types.VideoParams + if err := utils.JsonDecode(v.Params, ¶ms); err != nil { + logger.Errorf("unmarshal params failed: %v", err) + continue + } + s.PushTask(types.VideoTask{ + Id: v.Id, + Channel: v.Channel, + UserId: v.UserId, + Type: v.Type, + TaskId: v.TaskId, + Prompt: v.Prompt, + Params: params, + }) + } + logger.Info("Starting Video job consumer...") + go func() { + for { + var task types.VideoTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + var r RespVo + r, err = s.CreateLuma(task) + if err != nil { + logger.Errorf("create task with error: %v", err) + s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "err_msg": err.Error(), + "progress": service.FailTaskProgress, + }) + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) + continue + } + + // 更新任务信息 + s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "task_id": r.Id, + "channel": r.Channel, + }) + } + }() +} + +type RespVo struct { + Id string `json:"id"` + Prompt string `json:"prompt"` + State string `json:"state"` + CreatedAt time.Time `json:"created_at"` + Video interface{} `json:"video"` + Liked interface{} `json:"liked"` + EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` + Channel string `json:"channel,omitempty"` +} + +func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true) + if task.Channel != "" { + session = session.Where("api_url", task.Channel) + } + tx := session.Order("last_used_at DESC").First(&apiKey) + if tx.Error != nil { + return RespVo{}, errors.New("no available API KEY for Suno") + } + + reqBody := map[string]interface{}{ + "user_prompt": task.Prompt, + "expand_prompt": task.Params.PromptOptimize, + "loop": task.Params.Loop, + "image_url": task.Params.StartImgURL, + "image_end_url": task.Params.EndImgURL, + } + var res RespVo + apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL) + logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). + Post(apiURL) + if err != nil { + return RespVo{}, fmt.Errorf("请求 API 出错:%v", err) + } + + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + // update the last_use_at for api key + apiKey.LastUsedAt = time.Now().Unix() + session.Updates(&apiKey) + res.Channel = apiKey.ApiURL + return res, nil +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running Suno task notify checking ...") + for { + var message service.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadFiles() { + go func() { + var items []model.SunoJob + for { + res := s.db.Where("progress", 102).Find(&items) + if res.Error != nil { + continue + } + + for _, v := range items { + // 下载图片和音频 + logger.Infof("try download cover image: %s", v.CoverURL) + coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true) + if err != nil { + logger.Errorf("download image with error: %v", err) + continue + } + + logger.Infof("try download audio: %s", v.AudioURL) + audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true) + if err != nil { + logger.Errorf("download audio with error: %v", err) + continue + } + v.CoverURL = coverURL + v.AudioURL = audioURL + v.Progress = 100 + s.db.Updates(&v) + s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) + } + + time.Sleep(time.Second * 10) + } + }() +} + +// SyncTaskProgress 异步拉取任务 +func (s *Service) SyncTaskProgress() { + go func() { + var jobs []model.SunoJob + for { + res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) + if res.Error != nil { + continue + } + + for _, job := range jobs { + task, err := s.QueryTask(job.TaskId, job.Channel) + if err != nil { + logger.Errorf("query task with error: %v", err) + continue + } + + if task.Code != "success" { + logger.Errorf("query task with error: %v", task.Message) + continue + } + + logger.Debugf("task: %+v", task.Data.Status) + // 任务完成,删除旧任务插入两条新任务 + if task.Data.Status == "SUCCESS" { + var jobId = job.Id + var flag = false + tx := s.db.Begin() + for _, v := range task.Data.Data { + job.Id = 0 + job.Progress = 102 // 102 表示资源未下载完成 + job.Title = v.Title + job.SongId = v.Id + job.Duration = int(v.Metadata.Duration) + job.Prompt = v.Metadata.Prompt + job.Tags = v.Metadata.Tags + job.ModelName = v.ModelName + job.RawData = utils.JsonEncode(v) + job.CoverURL = v.ImageLargeUrl + job.AudioURL = v.AudioUrl + + if err = tx.Create(&job).Error; err != nil { + logger.Error("create job with error: %v", err) + tx.Rollback() + break + } + flag = true + } + + // 删除旧任务 + if flag { + if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil { + logger.Error("create job with error: %v", err) + tx.Rollback() + continue + } + } + tx.Commit() + + } else if task.Data.FailReason != "" { + job.Progress = service.FailTaskProgress + job.ErrMsg = task.Data.FailReason + s.db.Updates(&job) + s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + } + } + + time.Sleep(time.Second * 10) + } + }() +} + +type QueryRespVo struct { + Code string `json:"code"` + Message string `json:"message"` + Data struct { + TaskId string `json:"task_id"` + Action string `json:"action"` + Status string `json:"status"` + FailReason string `json:"fail_reason"` + SubmitTime int `json:"submit_time"` + StartTime int `json:"start_time"` + FinishTime int `json:"finish_time"` + Progress string `json:"progress"` + Data []struct { + Id string `json:"id"` + Title string `json:"title"` + Status string `json:"status"` + Metadata struct { + Tags string `json:"tags"` + Type string `json:"type"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` + Duration float64 `json:"duration"` + ErrorMessage interface{} `json:"error_message"` + } `json:"metadata"` + AudioUrl string `json:"audio_url"` + ImageUrl string `json:"image_url"` + VideoUrl string `json:"video_url"` + ModelName string `json:"model_name"` + DisplayName string `json:"display_name"` + ImageLargeUrl string `json:"image_large_url"` + MajorModelVersion string `json:"major_model_version"` + } `json:"data"` + } `json:"data"` +} + +func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + tx := s.db.Session(&gorm.Session{}).Where("type", "suno"). + Where("api_url", channel). + Where("enabled", true). + Order("last_used_at DESC").First(&apiKey) + if tx.Error != nil { + return QueryRespVo{}, errors.New("no available API KEY for Suno") + } + + apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId) + var res QueryRespVo + r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) + + if err != nil { + return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err) + } + + defer r.Body.Close() + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + return res, nil +} diff --git a/api/store/model/video_job.go b/api/store/model/video_job.go new file mode 100644 index 00000000..5dc7cb3e --- /dev/null +++ b/api/store/model/video_job.go @@ -0,0 +1,27 @@ +package model + +import "time" + +type VideoJob struct { + Id uint `gorm:"primarykey;column:id"` + UserId int + Channel string // 频道 + Type string // luma,runway,cog + TaskId string + Prompt string // 提示词 + PromptExt string // 优化后提示词 + CoverURL string // 封面图 URL + VideoURL string // 无水印视频 URL + WaterURL string // 有水印视频 URL + Progress int // 任务进度 + Publish bool // 是否发布 + ErrMsg string // 错误信息 + RawData string // 原始数据 json + Power int // 消耗算力 + Params string // 任务参数 + CreatedAt time.Time +} + +func (VideoJob) TableName() string { + return "chatgpt_video_jobs" +} diff --git a/api/store/vo/suno_job.go b/api/store/vo/suno_job.go index 97a18a3a..70dca573 100644 --- a/api/store/vo/suno_job.go +++ b/api/store/vo/suno_job.go @@ -28,7 +28,3 @@ type SunoJob struct { PlayTimes int `json:"play_times"` // 播放次数 CreatedAt int64 `json:"created_at"` } - -func (SunoJob) TableName() string { - return "chatgpt_suno_jobs" -} diff --git a/api/store/vo/video_job.go b/api/store/vo/video_job.go new file mode 100644 index 00000000..b7530132 --- /dev/null +++ b/api/store/vo/video_job.go @@ -0,0 +1,20 @@ +package vo + +type VideoJob struct { + Id uint `json:"id"` + UserId int `json:"user_id"` + Channel string `json:"channel"` + Type string `json:"type"` + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + PromptExt string `json:"prompt_ext"` // 提示词 + CoverURL string `json:"cover_url"` // 封面图 URL + VideoURL string `json:"video_url"` // 无水印视频 URL + WaterURL string `json:"water_url"` // 有水印视频 URL + Progress int `json:"progress"` // 任务进度 + Publish bool `json:"publish"` // 是否发布 + ErrMsg string `json:"err_msg"` // 错误信息 + RawData map[string]interface{} `json:"raw_data"` // 原始数据 json + Power int `json:"power"` // 消耗算力 + CreatedAt int64 `json:"created_at"` +} diff --git a/database/update-v4.1.3.sql b/database/update-v4.1.3.sql index 05f5cd3f..22dda9ba 100644 --- a/database/update-v4.1.3.sql +++ b/database/update-v4.1.3.sql @@ -1,2 +1,27 @@ ALTER TABLE `chatgpt_users` ADD `mobile` CHAR(11) NULL COMMENT '手机号' AFTER `username`; -ALTER TABLE `chatgpt_users` ADD `email` VARCHAR(50) NULL COMMENT '邮箱地址' AFTER `mobile`; \ No newline at end of file +ALTER TABLE `chatgpt_users` ADD `email` VARCHAR(50) NULL COMMENT '邮箱地址' AFTER `mobile`; + +CREATE TABLE `chatgpt_video_jobs` ( + `id` int NOT NULL, + `user_id` int NOT NULL COMMENT '用户 ID', + `channel` varchar(100) NOT NULL COMMENT '渠道', + `task_id` varchar(100) NOT NULL COMMENT '任务 ID', + `type` varchar(20) DEFAULT NULL COMMENT '任务类型,luma,runway,cogvideo', + `prompt` varchar(2000) NOT NULL COMMENT '提示词', + `prompt_ext` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '优化后提示词', + `cover_url` varchar(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '封面图地址', + `video_url` varchar(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '视频地址', + `water_url` varchar(512) DEFAULT NULL COMMENT '带水印的视频地址', + `progress` smallint DEFAULT '0' COMMENT '任务进度', + `publish` tinyint(1) NOT NULL COMMENT '是否发布', + `err_msg` varchar(255) DEFAULT NULL COMMENT '错误信息', + `raw_data` text COMMENT '原始数据', + `power` smallint NOT NULL DEFAULT '0' COMMENT '消耗算力', + `created_at` datetime NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='MidJourney 任务表'; + +ALTER TABLE `chatgpt_video_jobs`ADD PRIMARY KEY (`id`); + +ALTER TABLE `chatgpt_video_jobs` MODIFY `id` int NOT NULL AUTO_INCREMENT; + +ALTER TABLE `chatgpt_video_jobs` ADD `params` VARCHAR(512) NULL COMMENT '参数JSON' AFTER `raw_data`; \ No newline at end of file