mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-09 07:33:41 +08:00
@@ -19,3 +19,22 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
Init(info *relaycommon.TaskRelayInfo)
|
||||
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
|
||||
|
||||
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
|
||||
|
||||
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
}
|
||||
|
||||
@@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.BuildRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(requestBody), nil
|
||||
}
|
||||
|
||||
err = a.BuildRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
172
relay/channel/task/suno/adaptor.go
Normal file
172
relay/channel/task/suno/adaptor.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package suno
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
action := strings.ToUpper(c.Param("action"))
|
||||
|
||||
var sunoRequest *dto.SunoSubmitReq
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
err = actionValidate(c, sunoRequest, action)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if sunoRequest.ContinueClipId != "" {
|
||||
if sunoRequest.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
info.OriginTaskID = sunoRequest.TaskID
|
||||
}
|
||||
|
||||
info.Action = action
|
||||
c.Set("task_request", sunoRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
baseURL := info.BaseUrl
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data, err := json.Marshal(sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var sunoResponse dto.TaskResponse[string]
|
||||
err = json.Unmarshal(responseBody, &sunoResponse)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !sunoResponse.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
return sunoResponse.Data, nil, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task error: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer req.Body.Close()
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 15
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
|
||||
switch action {
|
||||
case constant.SunoActionMusic:
|
||||
if sunoRequest.Mv == "" {
|
||||
sunoRequest.Mv = "chirp-v3-0"
|
||||
}
|
||||
case constant.SunoActionLyrics:
|
||||
if sunoRequest.Prompt == "" {
|
||||
err = fmt.Errorf("prompt_empty")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("invalid_action")
|
||||
}
|
||||
return
|
||||
}
|
||||
7
relay/channel/task/suno/models.go
Normal file
7
relay/channel/task/suno/models.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package suno
|
||||
|
||||
var ModelList = []string{
|
||||
"suno_music", "suno_lyrics",
|
||||
}
|
||||
|
||||
var ChannelName = "suno"
|
||||
@@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
info.IsStream = isStream
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
StartTime time.Time
|
||||
ApiType int
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiKey string
|
||||
BaseUrl string
|
||||
|
||||
Action string
|
||||
OriginTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
}
|
||||
|
||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
StartTime: startTime,
|
||||
ApiType: apiType,
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package constant
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
@@ -26,6 +29,9 @@ const (
|
||||
RelayModeMidjourneyModal
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeSwapFace
|
||||
RelayModeSunoFetch
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
@@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelaySuno(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
|
||||
relayMode = RelayModeSunoFetch
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
|
||||
relayMode = RelayModeSunoFetchByID
|
||||
} else if strings.Contains(path, "/submit/") {
|
||||
relayMode = RelayModeSunoSubmit
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
commonconstant "one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/aws"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/task/suno"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
@@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||
switch platform {
|
||||
//case constant.APITypeAIProxyLibrary:
|
||||
// return &aiproxy.Adaptor{}
|
||||
case commonconstant.TaskPlatformSuno:
|
||||
return &suno.TaskAdaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
242
relay/relay_task.go
Normal file
242
relay/relay_task.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
/*
|
||||
Task 任务通过平台、Action 区分任务
|
||||
*/
|
||||
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
relayInfo := relaycommon.GenTaskRelayInfo(c)
|
||||
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
// get & validate taskRequest 获取并验证文本请求
|
||||
taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
|
||||
if taskErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||
modelPrice, success := common.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
if userQuota-quota < 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if relayInfo.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != relayInfo.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
relayInfo.BaseUrl = channel.GetBaseURL()
|
||||
relayInfo.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// do request
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// handle response
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if taskErr != nil {
|
||||
return
|
||||
}
|
||||
relayInfo.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
}
|
||||
|
||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
respBuilder, ok := fetchRespBuilders[relayMode]
|
||||
if !ok {
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
respBody, taskErr := respBuilder(c)
|
||||
if taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
userId := c.GetInt("id")
|
||||
var condition = struct {
|
||||
IDs []any `json:"ids"`
|
||||
Action string `json:"action"`
|
||||
}{}
|
||||
err := c.BindJSON(&condition)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var tasks []any
|
||||
if len(condition.IDs) > 0 {
|
||||
taskModels, err := model.GetByTaskIds(userId, condition.IDs)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
for _, task := range taskModels {
|
||||
tasks = append(tasks, TaskModel2Dto(task))
|
||||
}
|
||||
} else {
|
||||
tasks = make([]any, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
||||
Code: "success",
|
||||
Data: tasks,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("id")
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||
return &dto.TaskDto{
|
||||
TaskID: task.TaskID,
|
||||
Action: task.Action,
|
||||
Status: string(task.Status),
|
||||
FailReason: task.FailReason,
|
||||
SubmitTime: task.SubmitTime,
|
||||
StartTime: task.StartTime,
|
||||
FinishTime: task.FinishTime,
|
||||
Progress: task.Progress,
|
||||
Data: task.Data,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user