mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 10:43:44 +08:00
add put url file for oss interface
This commit is contained in:
@@ -166,7 +166,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
Style: task.Style,
|
||||
Quality: task.Quality,
|
||||
}
|
||||
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiURL, apiKey.Value, reqBody)
|
||||
logger.Infof("Sending %s request, Channel:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
@@ -259,7 +259,7 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func (p *ServicePool) DownloadImages() {
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
|
||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
|
||||
@@ -84,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -57,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(imageURL)
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -69,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
}
|
||||
|
||||
if useProxy {
|
||||
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
|
||||
err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
|
||||
} else {
|
||||
err = utils.DownloadFile(imageURL, filePath, "")
|
||||
err = utils.DownloadFile(fileURL, filePath, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
|
||||
@@ -44,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -65,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
filename,
|
||||
strings.NewReader(string(imageData)),
|
||||
int64(len(imageData)),
|
||||
strings.NewReader(string(fileData)),
|
||||
int64(len(fileData)),
|
||||
minio.PutObjectOptions{ContentType: "image/png"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -93,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
fileData, err = utils.DownloadImage(fileURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
parse, err := url.Parse(imageURL)
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ type File struct {
|
||||
}
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutImg(imageURL string, useProxy bool) (string, error)
|
||||
PutUrlFile(url string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package dalle
|
||||
package suno
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
@@ -8,11 +8,17 @@ package dalle
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -57,10 +63,230 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err := s.Create(task)
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
s.db.UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": 101,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Data,
|
||||
"channel": r.Channel,
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) Create(task types.SunoTask) {
|
||||
|
||||
type RespVo struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"task_id": task.TaskId,
|
||||
"continue_clip_id": task.RefSongId,
|
||||
"continue_at": task.ExtendSecs,
|
||||
"make_instrumental": task.Instrumental,
|
||||
}
|
||||
// 灵感模式
|
||||
if task.Type == 1 {
|
||||
reqBody["gpt_description_prompt"] = task.Prompt
|
||||
} else { // 自定义模式
|
||||
reqBody["prompt"] = task.Lyrics
|
||||
reqBody["tags"] = task.Tags
|
||||
reqBody["mv"] = task.Model
|
||||
reqBody["title"] = task.Title
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%s", string(body))
|
||||
}
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.SunoJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
task, err := s.QueryTask(job.TaskId, job.Channel)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if task.Code != "success" {
|
||||
logger.Errorf("query task with error: %v", task.Message)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task.Data.Status)
|
||||
// 任务完成,删除旧任务插入两条新任务
|
||||
if task.Data.Status == "SUCCESS" {
|
||||
var jobId = job.Id
|
||||
var flag = false
|
||||
tx := s.db.Begin()
|
||||
for _, v := range task.Data.Data {
|
||||
job.Id = 0
|
||||
job.Progress = 100
|
||||
job.Title = v.Title
|
||||
job.SongId = v.Id
|
||||
job.Duration = int(v.Metadata.Duration)
|
||||
job.Prompt = v.Metadata.Prompt
|
||||
job.Tags = v.Metadata.Tags
|
||||
job.ModelName = v.ModelName
|
||||
job.RawData = utils.JsonEncode(v)
|
||||
|
||||
// 下载图片和音频
|
||||
thumbURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageLargeUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
}
|
||||
job.ThumbImgURL = thumbURL
|
||||
job.CoverImgURL = coverURL
|
||||
job.AudioURL = audioURL
|
||||
|
||||
if err = tx.Create(&job).Error; err != nil {
|
||||
logger.Error("create job with error: %v", err)
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
flag = true
|
||||
}
|
||||
|
||||
// 删除旧任务
|
||||
if flag {
|
||||
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
|
||||
logger.Error("create job with error: %v", err)
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = 101
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type QueryRespVo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
SubmitTime int `json:"submit_time"`
|
||||
StartTime int `json:"start_time"`
|
||||
FinishTime int `json:"finish_time"`
|
||||
Progress string `json:"progress"`
|
||||
Data []struct {
|
||||
Id string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Status string `json:"status"`
|
||||
Metadata struct {
|
||||
Tags string `json:"tags"`
|
||||
Type string `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
Duration float64 `json:"duration"`
|
||||
ErrorMessage interface{} `json:"error_message"`
|
||||
} `json:"metadata"`
|
||||
AudioUrl string `json:"audio_url"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
ModelName string `json:"model_name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
ImageLargeUrl string `json:"image_large_url"`
|
||||
MajorModelVersion string `json:"major_model_version"`
|
||||
} `json:"data"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return QueryRespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
|
||||
var res QueryRespVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
|
||||
defer r.Body.Close()
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return QueryRespVo{}, fmt.Errorf("解析API数据失败:%s", string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user