merge v4.1.1 and fixed conflicts

This commit is contained in:
RockYang
2024-11-13 18:40:04 +08:00
110 changed files with 4820 additions and 1977 deletions

View File

@@ -110,12 +110,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt))
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
}
var user model.User
@@ -145,7 +144,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// get image generation API KEY
var apiKey model.ApiKey
tx = s.db.Where("type", "img").
tx = s.db.Where("type", "dalle").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
@@ -157,6 +156,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
@@ -165,14 +165,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody)
request := s.httpClient.R().SetHeader("Content-Type", "application/json")
if apiKey.Platform == types.Azure.Value {
request = request.SetHeader("api-key", apiKey.Value)
} else {
request = request.SetHeader("Authorization", "Bearer "+apiKey.Value)
}
r, err := request.SetBody(reqBody).SetErrorResult(&errRes).SetSuccessResult(&res).Post(apiKey.ApiURL)
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
SetErrorResult(&errRes).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
@@ -259,7 +258,7 @@ func (s *Service) DownloadImages() {
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
if err != nil {
return "", err
}

View File

@@ -67,25 +67,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
return c.doRequest(body, apiURL)
}
// Blend 融图
@@ -112,23 +94,7 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
}
}
}
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
return c.doRequest(body, apiURL)
}
// SwapFace 换脸
@@ -165,23 +131,7 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
},
"state": "",
}
var res ImageRes
var errRes ErrRes
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
return c.doRequest(body, apiURL)
}
// Upscale 放大指定的图片
@@ -195,24 +145,7 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
return c.doRequest(body, apiURL)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
@@ -226,9 +159,14 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
return c.doRequest(body, apiURL)
}
func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) {
var res ImageRes
var errRes ErrRes
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
@@ -236,7 +174,13 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
errMsg := err.Error()
if r != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Error("请求 API 出错:", string(errStr))
errMsg = errMsg + " " + string(errStr)
}
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
}
if r.IsErrorState() {

View File

@@ -8,7 +8,6 @@ package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
@@ -139,7 +138,7 @@ func (p *ServicePool) DownloadImages() {
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
@@ -188,28 +187,6 @@ func (p *ServicePool) SyncTaskProgress() {
}
for _, job := range jobs {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)
// 退回算力
tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
if tx.Error == nil && tx.RowsAffected > 0 {
var user model.User
p.db.Where("id = ?", job.UserId).First(&user)
p.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}

View File

@@ -29,6 +29,7 @@ type Service struct {
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
retryCount map[uint]int
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
@@ -39,9 +40,12 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
notifyQueue: notifyQueue,
Client: cli,
running: true,
retryCount: make(map[uint]int),
}
}
const failedProgress = 101
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for s.running {
@@ -55,15 +59,20 @@ func (s *Service) Run() {
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
if s.retryCount[task.Id] > 5 {
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
continue
}
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
s.retryCount[task.Id]++
time.Sleep(time.Second)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil {
task.Prompt = content
} else {
@@ -72,7 +81,7 @@ func (s *Service) Run() {
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini")
if err == nil {
task.NegPrompt = content
} else {
@@ -116,7 +125,7 @@ func (s *Service) Run() {
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.Progress = failedProgress
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
@@ -164,7 +173,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"progress": failedProgress,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})

View File

@@ -84,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil {
return "", err
}

View File

@@ -57,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
parse, err := url.Parse(imageURL)
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -69,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
}
if useProxy {
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
} else {
err = utils.DownloadFile(imageURL, filePath, "")
err = utils.DownloadFile(fileURL, filePath, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)

View File

@@ -44,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
}
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -65,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
context.Background(),
s.config.Bucket,
filename,
strings.NewReader(string(imageData)),
int64(len(imageData)),
strings.NewReader(string(fileData)),
int64(len(fileData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err

View File

@@ -93,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
fileData, err = utils.DownloadImage(fileURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
parse, err := url.Parse(imageURL)
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
@@ -113,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra)
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
if err != nil {
return "", err
}

View File

@@ -23,7 +23,7 @@ type File struct {
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutUrlFile(url string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -63,7 +63,7 @@ func (s *Service) Run() {
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
if err == nil {
task.Params.Prompt = content
} else {
@@ -73,7 +73,7 @@ func (s *Service) Run() {
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
if err == nil {
task.Params.NegPrompt = content
} else {

355
api/service/suno/service.go Normal file
View File

@@ -0,0 +1,355 @@
package suno
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
func (s *Service) PushTask(task types.SunoTask) {
logger.Infof("add a new Suno task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.SunoJob
s.db.Where("task_id", "").Find(&jobs)
for _, v := range jobs {
s.PushTask(types.SunoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
Title: v.Title,
RefTaskId: v.RefTaskId,
RefSongId: v.RefSongId,
Prompt: v.Prompt,
Tags: v.Tags,
Model: v.ModelName,
Instrumental: v.Instrumental,
ExtendSecs: v.ExtendSecs,
})
}
logger.Info("Starting Suno job consumer...")
go func() {
for {
var task types.SunoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
r, err := s.Create(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": 101,
})
continue
}
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Data,
"channel": r.Channel,
})
}
}()
}
type RespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data string `json:"data"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"task_id": task.RefTaskId,
"continue_clip_id": task.RefSongId,
"continue_at": task.ExtendSecs,
"make_instrumental": task.Instrumental,
}
// 灵感模式
if task.Type == 1 {
reqBody["gpt_description_prompt"] = task.Prompt
} else { // 自定义模式
reqBody["prompt"] = task.Prompt
reqBody["tags"] = task.Tags
reqBody["mv"] = task.Model
reqBody["title"] = task.Title
}
var res RespVo
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message sd.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.SunoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
v.CoverURL = coverURL
v.AudioURL = audioURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.SunoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
if task.Code != "success" {
logger.Errorf("query task with error: %v", task.Message)
continue
}
logger.Debugf("task: %+v", task.Data.Status)
// 任务完成,删除旧任务插入两条新任务
if task.Data.Status == "SUCCESS" {
var jobId = job.Id
var flag = false
tx := s.db.Begin()
for _, v := range task.Data.Data {
job.Id = 0
job.Progress = 102 // 102 表示资源未下载完成
job.Title = v.Title
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl
job.AudioURL = v.AudioUrl
if err = tx.Create(&job).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
break
}
flag = true
}
// 删除旧任务
if flag {
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
continue
}
}
tx.Commit()
} else if task.Data.FailReason != "" {
job.Progress = 101
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
}
}
time.Sleep(time.Second * 10)
}
}()
}
type QueryRespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data struct {
TaskId string `json:"task_id"`
Action string `json:"action"`
Status string `json:"status"`
FailReason string `json:"fail_reason"`
SubmitTime int `json:"submit_time"`
StartTime int `json:"start_time"`
FinishTime int `json:"finish_time"`
Progress string `json:"progress"`
Data []struct {
Id string `json:"id"`
Title string `json:"title"`
Status string `json:"status"`
Metadata struct {
Tags string `json:"tags"`
Type string `json:"type"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Duration float64 `json:"duration"`
ErrorMessage interface{} `json:"error_message"`
} `json:"metadata"`
AudioUrl string `json:"audio_url"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
ModelName string `json:"model_name"`
DisplayName string `json:"display_name"`
ImageLargeUrl string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
} `json:"data"`
} `json:"data"`
}
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return QueryRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}