luma create and list api is ready

This commit is contained in:
RockYang
2024-09-02 18:08:50 +08:00
parent d70fcbb9f3
commit aa46fdb113
9 changed files with 311 additions and 115 deletions

View File

@@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {

View File

@@ -82,8 +82,8 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
var r RespVo
r, err = s.CreateLuma(task)
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
@@ -96,14 +96,15 @@ func (s *Service) Run() {
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
})
}
}()
}
type RespVo struct {
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
@@ -114,7 +115,7 @@ type RespVo struct {
Channel string `json:"channel,omitempty"`
}
func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
@@ -123,7 +124,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
@@ -133,7 +134,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res RespVo
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
@@ -141,13 +142,17 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
@@ -180,7 +185,7 @@ func (s *Service) CheckTaskNotify() {
func (s *Service) DownloadFiles() {
go func() {
var items []model.SunoJob
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
@@ -188,22 +193,13 @@ func (s *Service) DownloadFiles() {
}
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
logger.Infof("try download video: %s", v.VideoURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
v.CoverURL = coverURL
v.AudioURL = audioURL
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
@@ -217,7 +213,7 @@ func (s *Service) DownloadFiles() {
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.SunoJob
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
@@ -225,60 +221,14 @@ func (s *Service) SyncTaskProgress() {
}
for _, job := range jobs {
task, err := s.QueryTask(job.TaskId, job.Channel)
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
if task.Code != "success" {
logger.Errorf("query task with error: %v", task.Message)
continue
}
logger.Debugf("task: %+v", task)
logger.Debugf("task: %+v", task.Data.Status)
// 任务完成,删除旧任务插入两条新任务
if task.Data.Status == "SUCCESS" {
var jobId = job.Id
var flag = false
tx := s.db.Begin()
for _, v := range task.Data.Data {
job.Id = 0
job.Progress = 102 // 102 表示资源未下载完成
job.Title = v.Title
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl
job.AudioURL = v.AudioUrl
if err = tx.Create(&job).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
break
}
flag = true
}
// 删除旧任务
if flag {
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
continue
}
}
tx.Commit()
} else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
}
}
time.Sleep(time.Second * 10)
@@ -286,42 +236,22 @@ func (s *Service) SyncTaskProgress() {
}()
}
type QueryRespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data struct {
TaskId string `json:"task_id"`
Action string `json:"action"`
Status string `json:"status"`
FailReason string `json:"fail_reason"`
SubmitTime int `json:"submit_time"`
StartTime int `json:"start_time"`
FinishTime int `json:"finish_time"`
Progress string `json:"progress"`
Data []struct {
Id string `json:"id"`
Title string `json:"title"`
Status string `json:"status"`
Metadata struct {
Tags string `json:"tags"`
Type string `json:"type"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Duration float64 `json:"duration"`
ErrorMessage interface{} `json:"error_message"`
} `json:"metadata"`
AudioUrl string `json:"audio_url"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
ModelName string `json:"model_name"`
DisplayName string `json:"display_name"`
ImageLargeUrl string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
} `json:"data"`
} `json:"data"`
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
@@ -329,22 +259,22 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error)
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
return LumaTaskVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return QueryRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil