mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: suno api 支持
feat: 调试 suno feat: 补充suno 文档
This commit is contained in:
		@@ -47,6 +47,11 @@
 | 
			
		||||
    2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
 | 
			
		||||
    3. 选择你的bot,然后输入http(s)://你的网站地址/login
 | 
			
		||||
    4. Telegram Bot 名称是bot username 去掉@后的字符串
 | 
			
		||||
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下:
 | 
			
		||||
    + [x] /suno/submit/music
 | 
			
		||||
    + [x] /suno/submit/lyrics
 | 
			
		||||
    + [x] /suno/fetch
 | 
			
		||||
    + [x] /suno/fetch/:id
 | 
			
		||||
 | 
			
		||||
## 模型支持
 | 
			
		||||
此版本额外支持以下模型:
 | 
			
		||||
@@ -57,6 +62,7 @@
 | 
			
		||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
 | 
			
		||||
6. [零一万物](https://platform.lingyiwanwu.com/)
 | 
			
		||||
7. 自定义渠道,支持填入完整调用地址
 | 
			
		||||
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
 | 
			
		||||
 | 
			
		||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
 | 
			
		||||
 | 
			
		||||
@@ -105,6 +111,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
 | 
			
		||||
## Midjourney接口设置文档
 | 
			
		||||
[对接文档](Midjourney.md)
 | 
			
		||||
 | 
			
		||||
## Suno接口设置文档
 | 
			
		||||
[对接文档](Suno.md)
 | 
			
		||||
 | 
			
		||||
## 交流群
 | 
			
		||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										37
									
								
								Suno.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								Suno.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
			
		||||
# Suno API文档
 | 
			
		||||
 | 
			
		||||
**简介**:Suno API文档
 | 
			
		||||
 | 
			
		||||
## 模型列表
 | 
			
		||||
 | 
			
		||||
### Suno API支持
 | 
			
		||||
 | 
			
		||||
- suno_music (自定义模式、灵感模式、续写)
 | 
			
		||||
- suno_lyrics (生成歌词)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
 | 
			
		||||
```json
 | 
			
		||||
{
 | 
			
		||||
  "suno_music": 0.3,
 | 
			
		||||
  "suno_lyrics": 0.01
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## 渠道设置
 | 
			
		||||
 | 
			
		||||
### 对接 Suno API
 | 
			
		||||
 | 
			
		||||
1.
 | 
			
		||||
部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
 | 
			
		||||
 | 
			
		||||
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
 | 
			
		||||
   ,模型请参考上方模型列表
 | 
			
		||||
3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
 | 
			
		||||
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
 | 
			
		||||
 | 
			
		||||
### 对接上游new api
 | 
			
		||||
 | 
			
		||||
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
 | 
			
		||||
2. **代理**填写上游new api的地址,例如:http://localhost:3000
 | 
			
		||||
3. 密钥填写上游new api的密钥
 | 
			
		||||
@@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 | 
			
		||||
var DisplayInCurrencyEnabled = true
 | 
			
		||||
var DisplayTokenStatEnabled = true
 | 
			
		||||
var DrawingEnabled = true
 | 
			
		||||
var TaskEnabled = true
 | 
			
		||||
var DataExportEnabled = true
 | 
			
		||||
var DataExportInterval = 5         // unit: minute
 | 
			
		||||
var DataExportDefaultTime = "hour" // unit: minute
 | 
			
		||||
@@ -208,8 +209,10 @@ const (
 | 
			
		||||
	ChannelTypeAws            = 33
 | 
			
		||||
	ChannelTypeCohere         = 34
 | 
			
		||||
	ChannelTypeMiniMax        = 35
 | 
			
		||||
	ChannelTypeSunoAPI        = 36
 | 
			
		||||
 | 
			
		||||
	ChannelTypeDummy // this one is only for count, do not add any channel after this
 | 
			
		||||
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
@@ -249,4 +252,5 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                                          //33
 | 
			
		||||
	"https://api.cohere.ai",                     //34
 | 
			
		||||
	"https://api.minimax.chat",                  //35
 | 
			
		||||
	"",                                          //36
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										18
									
								
								constant/task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								constant/task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
package constant
 | 
			
		||||
 | 
			
		||||
type TaskPlatform string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TaskPlatformSuno       TaskPlatform = "suno"
 | 
			
		||||
	TaskPlatformMidjourney              = "mj"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	SunoActionMusic  = "MUSIC"
 | 
			
		||||
	SunoActionLyrics = "LYRICS"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var SunoModel2Action = map[string]string{
 | 
			
		||||
	"suno_music":  SunoActionMusic,
 | 
			
		||||
	"suno_lyrics": SunoActionLyrics,
 | 
			
		||||
}
 | 
			
		||||
@@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 | 
			
		||||
	if channel.Type == common.ChannelTypeMidjourney {
 | 
			
		||||
		return errors.New("midjourney channel test is not supported"), nil
 | 
			
		||||
	}
 | 
			
		||||
	if channel.Type == common.ChannelTypeSunoAPI {
 | 
			
		||||
		return errors.New("suno channel test is not supported"), nil
 | 
			
		||||
	}
 | 
			
		||||
	w := httptest.NewRecorder()
 | 
			
		||||
	c, _ := gin.CreateTestContext(w)
 | 
			
		||||
	c.Request = &http.Request{
 | 
			
		||||
 
 | 
			
		||||
@@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) {
 | 
			
		||||
			"display_in_currency":      common.DisplayInCurrencyEnabled,
 | 
			
		||||
			"enable_batch_update":      common.BatchUpdateEnabled,
 | 
			
		||||
			"enable_drawing":           common.DrawingEnabled,
 | 
			
		||||
			"enable_task":              common.TaskEnabled,
 | 
			
		||||
			"enable_data_export":       common.DataExportEnabled,
 | 
			
		||||
			"data_export_default_time": common.DataExportDefaultTime,
 | 
			
		||||
			"default_collapse_sidebar": common.DefaultCollapseSidebar,
 | 
			
		||||
 
 | 
			
		||||
@@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) {
 | 
			
		||||
		"error": err,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayTask(c *gin.Context) {
 | 
			
		||||
	retryTimes := common.RetryTimes
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	relayMode := c.GetInt("relay_mode")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	originalModel := c.GetString("original_model")
 | 
			
		||||
	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 | 
			
		||||
	taskErr := taskRelayHandler(c, relayMode)
 | 
			
		||||
	if taskErr == nil {
 | 
			
		||||
		retryTimes = 0
 | 
			
		||||
	}
 | 
			
		||||
	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
 | 
			
		||||
		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		channelId = channel.Id
 | 
			
		||||
		useChannel := c.GetStringSlice("use_channel")
 | 
			
		||||
		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 | 
			
		||||
		c.Set("use_channel", useChannel)
 | 
			
		||||
		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
 | 
			
		||||
		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 | 
			
		||||
 | 
			
		||||
		requestBody, err := common.GetRequestBody(c)
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
		taskErr = taskRelayHandler(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	useChannel := c.GetStringSlice("use_channel")
 | 
			
		||||
	if len(useChannel) > 1 {
 | 
			
		||||
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 | 
			
		||||
		common.LogInfo(c.Request.Context(), retryLogStr)
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr != nil {
 | 
			
		||||
		if taskErr.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(taskErr.StatusCode, taskErr)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
 | 
			
		||||
	var err *dto.TaskError
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
 | 
			
		||||
		err = relay.RelayTaskFetch(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = relay.RelayTaskSubmit(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
 | 
			
		||||
	if taskErr == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if retryTimes <= 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := c.Get("specific_channel_id"); ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode == 307 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode/100 == 5 {
 | 
			
		||||
		// 超时不重试
 | 
			
		||||
		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode == http.StatusBadRequest {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode == 408 {
 | 
			
		||||
		// azure处理超时不重试
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.LocalError {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if taskErr.StatusCode/100 == 2 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										284
									
								
								controller/task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										284
									
								
								controller/task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,284 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/samber/lo"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"one-api/relay"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func UpdateTaskBulk() {
 | 
			
		||||
	//revocer
 | 
			
		||||
	//imageModel := "midjourney"
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Duration(15) * time.Second)
 | 
			
		||||
		common.SysLog("任务进度轮询开始")
 | 
			
		||||
		ctx := context.TODO()
 | 
			
		||||
		allTasks := model.GetAllUnFinishSyncTasks(500)
 | 
			
		||||
		platformTask := make(map[constant.TaskPlatform][]*model.Task)
 | 
			
		||||
		for _, t := range allTasks {
 | 
			
		||||
			platformTask[t.Platform] = append(platformTask[t.Platform], t)
 | 
			
		||||
		}
 | 
			
		||||
		for platform, tasks := range platformTask {
 | 
			
		||||
			if len(tasks) == 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			taskChannelM := make(map[int][]string)
 | 
			
		||||
			taskM := make(map[string]*model.Task)
 | 
			
		||||
			nullTaskIds := make([]int64, 0)
 | 
			
		||||
			for _, task := range tasks {
 | 
			
		||||
				if task.TaskID == "" {
 | 
			
		||||
					// 统计失败的未完成任务
 | 
			
		||||
					nullTaskIds = append(nullTaskIds, task.ID)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				taskM[task.TaskID] = task
 | 
			
		||||
				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
 | 
			
		||||
			}
 | 
			
		||||
			if len(nullTaskIds) > 0 {
 | 
			
		||||
				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
 | 
			
		||||
					"status":   "FAILURE",
 | 
			
		||||
					"progress": "100%",
 | 
			
		||||
				})
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
 | 
			
		||||
				} else {
 | 
			
		||||
					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if len(taskChannelM) == 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			UpdateTaskByPlatform(platform, taskChannelM, taskM)
 | 
			
		||||
		}
 | 
			
		||||
		common.SysLog("任务进度轮询完成")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
 | 
			
		||||
	switch platform {
 | 
			
		||||
	case constant.TaskPlatformMidjourney:
 | 
			
		||||
		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
 | 
			
		||||
	case constant.TaskPlatformSuno:
 | 
			
		||||
		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
 | 
			
		||||
	default:
 | 
			
		||||
		common.SysLog("未知平台")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
 | 
			
		||||
	for channelId, taskIds := range taskChannelM {
 | 
			
		||||
		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
 | 
			
		||||
	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
 | 
			
		||||
	if len(taskIds) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	channel, err := model.CacheGetChannel(channelId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
 | 
			
		||||
		err = model.TaskBulkUpdate(taskIds, map[string]any{
 | 
			
		||||
			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
 | 
			
		||||
			"status":      "FAILURE",
 | 
			
		||||
			"progress":    "100%",
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return errors.New("adaptor not found")
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
 | 
			
		||||
		"ids": taskIds,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 | 
			
		||||
		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
 | 
			
		||||
	err = json.Unmarshal(responseBody, &responseItems)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if !responseItems.IsSuccess() {
 | 
			
		||||
		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, responseItem := range responseItems.Data {
 | 
			
		||||
		task := taskM[responseItem.TaskID]
 | 
			
		||||
		if !checkTaskNeedUpdate(task, responseItem) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
 | 
			
		||||
		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
 | 
			
		||||
		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
 | 
			
		||||
		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
 | 
			
		||||
		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
 | 
			
		||||
		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
 | 
			
		||||
			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
 | 
			
		||||
			task.Progress = "100%"
 | 
			
		||||
			err = model.CacheUpdateUserQuota(task.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				quota := task.Quota
 | 
			
		||||
				if quota != 0 {
 | 
			
		||||
					err = model.IncreaseUserQuota(task.UserId, quota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.LogError(ctx, "fail to increase user quota: "+err.Error())
 | 
			
		||||
					}
 | 
			
		||||
					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
 | 
			
		||||
					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if responseItem.Status == model.TaskStatusSuccess {
 | 
			
		||||
			task.Progress = "100%"
 | 
			
		||||
		}
 | 
			
		||||
		task.Data = responseItem.Data
 | 
			
		||||
 | 
			
		||||
		err = task.Update()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("UpdateMidjourneyTask task error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
 | 
			
		||||
 | 
			
		||||
	if oldTask.SubmitTime != newTask.SubmitTime {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if oldTask.StartTime != newTask.StartTime {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if oldTask.FinishTime != newTask.FinishTime {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if string(oldTask.Status) != newTask.Status {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if oldTask.FailReason != newTask.FailReason {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if oldTask.FinishTime != newTask.FinishTime {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oldData, _ := json.Marshal(oldTask.Data)
 | 
			
		||||
	newData, _ := json.Marshal(newTask.Data)
 | 
			
		||||
 | 
			
		||||
	sort.Slice(oldData, func(i, j int) bool {
 | 
			
		||||
		return oldData[i] < oldData[j]
 | 
			
		||||
	})
 | 
			
		||||
	sort.Slice(newData, func(i, j int) bool {
 | 
			
		||||
		return newData[i] < newData[j]
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if string(oldData) != string(newData) {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllTask(c *gin.Context) {
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
	// 解析其他查询参数
 | 
			
		||||
	queryParams := model.SyncTaskQueryParams{
 | 
			
		||||
		Platform:       constant.TaskPlatform(c.Query("platform")),
 | 
			
		||||
		TaskID:         c.Query("task_id"),
 | 
			
		||||
		Status:         c.Query("status"),
 | 
			
		||||
		Action:         c.Query("action"),
 | 
			
		||||
		StartTimestamp: startTimestamp,
 | 
			
		||||
		EndTimestamp:   endTimestamp,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
 | 
			
		||||
	if logs == nil {
 | 
			
		||||
		logs = make([]*model.Task, 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserTask(c *gin.Context) {
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
 | 
			
		||||
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
 | 
			
		||||
	queryParams := model.SyncTaskQueryParams{
 | 
			
		||||
		Platform:       constant.TaskPlatform(c.Query("platform")),
 | 
			
		||||
		TaskID:         c.Query("task_id"),
 | 
			
		||||
		Status:         c.Query("status"),
 | 
			
		||||
		Action:         c.Query("action"),
 | 
			
		||||
		StartTimestamp: startTimestamp,
 | 
			
		||||
		EndTimestamp:   endTimestamp,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
 | 
			
		||||
	if logs == nil {
 | 
			
		||||
		logs = make([]*model.Task, 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										129
									
								
								dto/suno.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								dto/suno.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,129 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TaskData interface {
 | 
			
		||||
	SunoDataResponse | []SunoDataResponse | string | any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoSubmitReq struct {
 | 
			
		||||
	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
 | 
			
		||||
	Prompt               string  `json:"prompt,omitempty"`
 | 
			
		||||
	Mv                   string  `json:"mv,omitempty"`
 | 
			
		||||
	Title                string  `json:"title,omitempty"`
 | 
			
		||||
	Tags                 string  `json:"tags,omitempty"`
 | 
			
		||||
	ContinueAt           float64 `json:"continue_at,omitempty"`
 | 
			
		||||
	TaskID               string  `json:"task_id,omitempty"`
 | 
			
		||||
	ContinueClipId       string  `json:"continue_clip_id,omitempty"`
 | 
			
		||||
	MakeInstrumental     bool    `json:"make_instrumental"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type FetchReq struct {
 | 
			
		||||
	IDs []string `json:"ids"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoDataResponse struct {
 | 
			
		||||
	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
 | 
			
		||||
	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
 | 
			
		||||
	Status     string          `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
 | 
			
		||||
	FailReason string          `json:"fail_reason"`
 | 
			
		||||
	SubmitTime int64           `json:"submit_time" gorm:"index"`
 | 
			
		||||
	StartTime  int64           `json:"start_time" gorm:"index"`
 | 
			
		||||
	FinishTime int64           `json:"finish_time" gorm:"index"`
 | 
			
		||||
	Data       json.RawMessage `json:"data" gorm:"type:json"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoSong struct {
 | 
			
		||||
	ID                string       `json:"id"`
 | 
			
		||||
	VideoURL          string       `json:"video_url"`
 | 
			
		||||
	AudioURL          string       `json:"audio_url"`
 | 
			
		||||
	ImageURL          string       `json:"image_url"`
 | 
			
		||||
	ImageLargeURL     string       `json:"image_large_url"`
 | 
			
		||||
	MajorModelVersion string       `json:"major_model_version"`
 | 
			
		||||
	ModelName         string       `json:"model_name"`
 | 
			
		||||
	Status            string       `json:"status"`
 | 
			
		||||
	Title             string       `json:"title"`
 | 
			
		||||
	Text              string       `json:"text"`
 | 
			
		||||
	Metadata          SunoMetadata `json:"metadata"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoMetadata struct {
 | 
			
		||||
	Tags                 string      `json:"tags"`
 | 
			
		||||
	Prompt               string      `json:"prompt"`
 | 
			
		||||
	GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
 | 
			
		||||
	AudioPromptID        interface{} `json:"audio_prompt_id"`
 | 
			
		||||
	Duration             interface{} `json:"duration"`
 | 
			
		||||
	ErrorType            interface{} `json:"error_type"`
 | 
			
		||||
	ErrorMessage         interface{} `json:"error_message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoLyrics struct {
 | 
			
		||||
	ID     string `json:"id"`
 | 
			
		||||
	Status string `json:"status"`
 | 
			
		||||
	Title  string `json:"title"`
 | 
			
		||||
	Text   string `json:"text"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const TaskSuccessCode = "success"
 | 
			
		||||
 | 
			
		||||
type TaskResponse[T TaskData] struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Data    T      `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TaskResponse[T]) IsSuccess() bool {
 | 
			
		||||
	return t.Code == TaskSuccessCode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskDto struct {
 | 
			
		||||
	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
 | 
			
		||||
	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
 | 
			
		||||
	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
 | 
			
		||||
	FailReason string          `json:"fail_reason"`
 | 
			
		||||
	SubmitTime int64           `json:"submit_time"`
 | 
			
		||||
	StartTime  int64           `json:"start_time"`
 | 
			
		||||
	FinishTime int64           `json:"finish_time"`
 | 
			
		||||
	Progress   string          `json:"progress"`
 | 
			
		||||
	Data       json.RawMessage `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoGoAPISubmitReq struct {
 | 
			
		||||
	CustomMode bool `json:"custom_mode"`
 | 
			
		||||
 | 
			
		||||
	Input SunoGoAPISubmitReqInput `json:"input"`
 | 
			
		||||
 | 
			
		||||
	NotifyHook string `json:"notify_hook,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoGoAPISubmitReqInput struct {
 | 
			
		||||
	GptDescriptionPrompt string  `json:"gpt_description_prompt"`
 | 
			
		||||
	Prompt               string  `json:"prompt"`
 | 
			
		||||
	Mv                   string  `json:"mv"`
 | 
			
		||||
	Title                string  `json:"title"`
 | 
			
		||||
	Tags                 string  `json:"tags"`
 | 
			
		||||
	ContinueAt           float64 `json:"continue_at"`
 | 
			
		||||
	TaskID               string  `json:"task_id"`
 | 
			
		||||
	ContinueClipId       string  `json:"continue_clip_id"`
 | 
			
		||||
	MakeInstrumental     bool    `json:"make_instrumental"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GoAPITaskResponse[T any] struct {
 | 
			
		||||
	Code         int    `json:"code"`
 | 
			
		||||
	Message      string `json:"message"`
 | 
			
		||||
	Data         T      `json:"data"`
 | 
			
		||||
	ErrorMessage string `json:"error_message,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GoAPITaskResponseData struct {
 | 
			
		||||
	TaskID string `json:"task_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GoAPIFetchResponseData struct {
 | 
			
		||||
	TaskID string              `json:"task_id"`
 | 
			
		||||
	Status string              `json:"status"`
 | 
			
		||||
	Input  string              `json:"input"`
 | 
			
		||||
	Clips  map[string]SunoSong `json:"clips"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										10
									
								
								dto/task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								dto/task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type TaskError struct {
 | 
			
		||||
	Code       string `json:"code"`
 | 
			
		||||
	Message    string `json:"message"`
 | 
			
		||||
	Data       any    `json:"data"`
 | 
			
		||||
	StatusCode int    `json:"-"`
 | 
			
		||||
	LocalError bool   `json:"-"`
 | 
			
		||||
	Error      error  `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@@ -20,10 +20,10 @@ import (
 | 
			
		||||
	_ "net/http/pprof"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//go:embed web/dist
 | 
			
		||||
// /go:embed web/dist
 | 
			
		||||
var buildFS embed.FS
 | 
			
		||||
 | 
			
		||||
//go:embed web/dist/index.html
 | 
			
		||||
// /go:embed web/dist/index.html
 | 
			
		||||
var indexPage []byte
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
@@ -92,6 +92,9 @@ func main() {
 | 
			
		||||
	common.SafeGoroutine(func() {
 | 
			
		||||
		controller.UpdateMidjourneyTaskBulk()
 | 
			
		||||
	})
 | 
			
		||||
	common.SafeGoroutine(func() {
 | 
			
		||||
		controller.UpdateTaskBulk()
 | 
			
		||||
	})
 | 
			
		||||
	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 | 
			
		||||
		common.BatchUpdateEnabled = true
 | 
			
		||||
		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
			
		||||
 
 | 
			
		||||
@@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 | 
			
		||||
			modelRequest.Model = midjourneyModel
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("relay_mode", relayMode)
 | 
			
		||||
	} else if strings.Contains(c.Request.URL.Path, "/suno/") {
 | 
			
		||||
		relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
 | 
			
		||||
		if relayMode == relayconstant.RelayModeSunoFetch ||
 | 
			
		||||
			relayMode == relayconstant.RelayModeSunoFetchByID {
 | 
			
		||||
			shouldSelectChannel = false
 | 
			
		||||
		} else {
 | 
			
		||||
			modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
 | 
			
		||||
			modelRequest.Model = modelName
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("platform", string(constant.TaskPlatformSuno))
 | 
			
		||||
		c.Set("relay_mode", relayMode)
 | 
			
		||||
	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
		err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -140,6 +140,10 @@ func InitDB() (err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Task{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		common.SysLog("database migrated")
 | 
			
		||||
		err = createRootAccountIfNeed()
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
@@ -41,6 +41,7 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
 | 
			
		||||
	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
 | 
			
		||||
	common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
 | 
			
		||||
	common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
 | 
			
		||||
	common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
 | 
			
		||||
	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
 | 
			
		||||
	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
 | 
			
		||||
@@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			common.DisplayTokenStatEnabled = boolValue
 | 
			
		||||
		case "DrawingEnabled":
 | 
			
		||||
			common.DrawingEnabled = boolValue
 | 
			
		||||
		case "TaskEnabled":
 | 
			
		||||
			common.TaskEnabled = boolValue
 | 
			
		||||
		case "DataExportEnabled":
 | 
			
		||||
			common.DataExportEnabled = boolValue
 | 
			
		||||
		case "DefaultCollapseSidebar":
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										304
									
								
								model/task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										304
									
								
								model/task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,304 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	commonRelay "one-api/relay/common"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TaskStatus string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TaskStatusNotStart   TaskStatus = "NOT_START"
 | 
			
		||||
	TaskStatusSubmitted             = "SUBMITTED"
 | 
			
		||||
	TaskStatusQueued                = "QUEUED"
 | 
			
		||||
	TaskStatusInProgress            = "IN_PROGRESS"
 | 
			
		||||
	TaskStatusFailure               = "FAILURE"
 | 
			
		||||
	TaskStatusSuccess               = "SUCCESS"
 | 
			
		||||
	TaskStatusUnknown               = "UNKNOWN"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Task struct {
 | 
			
		||||
	ID         int64                 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
 | 
			
		||||
	CreatedAt  int64                 `json:"created_at" gorm:"index"`
 | 
			
		||||
	UpdatedAt  int64                 `json:"updated_at"`
 | 
			
		||||
	TaskID     string                `json:"task_id" gorm:"type:varchar(50);index"`  // 第三方id,不一定有/ song id\ Task id
 | 
			
		||||
	Platform   constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
 | 
			
		||||
	UserId     int                   `json:"user_id" gorm:"index"`
 | 
			
		||||
	ChannelId  int                   `json:"channel_id" gorm:"index"`
 | 
			
		||||
	Quota      int                   `json:"quota"`
 | 
			
		||||
	Action     string                `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
 | 
			
		||||
	Status     TaskStatus            `json:"status" gorm:"type:varchar(20);index"` // 任务状态
 | 
			
		||||
	FailReason string                `json:"fail_reason"`
 | 
			
		||||
	SubmitTime int64                 `json:"submit_time" gorm:"index"`
 | 
			
		||||
	StartTime  int64                 `json:"start_time" gorm:"index"`
 | 
			
		||||
	FinishTime int64                 `json:"finish_time" gorm:"index"`
 | 
			
		||||
	Progress   string                `json:"progress" gorm:"type:varchar(20);index"`
 | 
			
		||||
	Properties Properties            `json:"properties" gorm:"type:json"`
 | 
			
		||||
 | 
			
		||||
	Data json.RawMessage `json:"data" gorm:"type:json"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Task) SetData(data any) {
 | 
			
		||||
	b, _ := json.Marshal(data)
 | 
			
		||||
	t.Data = json.RawMessage(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Task) GetData(v any) error {
 | 
			
		||||
	err := json.Unmarshal(t.Data, &v)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Properties struct {
 | 
			
		||||
	Input string `json:"input"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Properties) Scan(val interface{}) error {
 | 
			
		||||
	bytesValue, _ := val.([]byte)
 | 
			
		||||
	return json.Unmarshal(bytesValue, m)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Properties) Value() (driver.Value, error) {
 | 
			
		||||
	return json.Marshal(m)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
 | 
			
		||||
type SyncTaskQueryParams struct {
 | 
			
		||||
	Platform       constant.TaskPlatform
 | 
			
		||||
	ChannelID      string
 | 
			
		||||
	TaskID         string
 | 
			
		||||
	UserID         string
 | 
			
		||||
	Action         string
 | 
			
		||||
	Status         string
 | 
			
		||||
	StartTimestamp int64
 | 
			
		||||
	EndTimestamp   int64
 | 
			
		||||
	UserIDs        []int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
 | 
			
		||||
	t := &Task{
 | 
			
		||||
		UserId:     relayInfo.UserId,
 | 
			
		||||
		SubmitTime: time.Now().Unix(),
 | 
			
		||||
		Status:     TaskStatusNotStart,
 | 
			
		||||
		Progress:   "0%",
 | 
			
		||||
		ChannelId:  relayInfo.ChannelId,
 | 
			
		||||
		Platform:   platform,
 | 
			
		||||
	}
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
 | 
			
		||||
	var tasks []*Task
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	// 初始化查询构建器
 | 
			
		||||
	query := DB.Where("user_id = ?", userId)
 | 
			
		||||
 | 
			
		||||
	if queryParams.TaskID != "" {
 | 
			
		||||
		query = query.Where("task_id = ?", queryParams.TaskID)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Action != "" {
 | 
			
		||||
		query = query.Where("action = ?", queryParams.Action)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Status != "" {
 | 
			
		||||
		query = query.Where("status = ?", queryParams.Status)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Platform != "" {
 | 
			
		||||
		query = query.Where("platform = ?", queryParams.Platform)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.StartTimestamp != 0 {
 | 
			
		||||
		// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
 | 
			
		||||
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.EndTimestamp != 0 {
 | 
			
		||||
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取数据
 | 
			
		||||
	err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tasks
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
 | 
			
		||||
	var tasks []*Task
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	// 初始化查询构建器
 | 
			
		||||
	query := DB
 | 
			
		||||
 | 
			
		||||
	// 添加过滤条件
 | 
			
		||||
	if queryParams.ChannelID != "" {
 | 
			
		||||
		query = query.Where("channel_id = ?", queryParams.ChannelID)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Platform != "" {
 | 
			
		||||
		query = query.Where("platform = ?", queryParams.Platform)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.UserID != "" {
 | 
			
		||||
		query = query.Where("user_id = ?", queryParams.UserID)
 | 
			
		||||
	}
 | 
			
		||||
	if len(queryParams.UserIDs) != 0 {
 | 
			
		||||
		query = query.Where("user_id in (?)", queryParams.UserIDs)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.TaskID != "" {
 | 
			
		||||
		query = query.Where("task_id = ?", queryParams.TaskID)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Action != "" {
 | 
			
		||||
		query = query.Where("action = ?", queryParams.Action)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Status != "" {
 | 
			
		||||
		query = query.Where("status = ?", queryParams.Status)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.StartTimestamp != 0 {
 | 
			
		||||
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.EndTimestamp != 0 {
 | 
			
		||||
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取数据
 | 
			
		||||
	err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tasks
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUnFinishSyncTasks(limit int) []*Task {
 | 
			
		||||
	var tasks []*Task
 | 
			
		||||
	var err error
 | 
			
		||||
	// get all tasks progress is not 100%
 | 
			
		||||
	err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return tasks
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
 | 
			
		||||
	if taskId == "" {
 | 
			
		||||
		return nil, false, nil
 | 
			
		||||
	}
 | 
			
		||||
	var task *Task
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Where("task_id = ?", taskId).First(&task).Error
 | 
			
		||||
	exist, err := RecordExist(err)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
	return task, exist, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
 | 
			
		||||
	if taskId == "" {
 | 
			
		||||
		return nil, false, nil
 | 
			
		||||
	}
 | 
			
		||||
	var task *Task
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
 | 
			
		||||
		First(&task).Error
 | 
			
		||||
	exist, err := RecordExist(err)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
	return task, exist, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
 | 
			
		||||
	if len(taskIds) == 0 {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	var task []*Task
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
 | 
			
		||||
		Find(&task).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return task, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskUpdateProgress(id int64, progress string) error {
 | 
			
		||||
	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (Task *Task) Insert() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(Task).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (Task *Task) Update() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Save(Task).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
 | 
			
		||||
	if len(TaskIds) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return DB.Model(&Task{}).
 | 
			
		||||
		Where("task_id in (?)", TaskIds).
 | 
			
		||||
		Updates(params).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
 | 
			
		||||
	if len(taskIDs) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return DB.Model(&Task{}).
 | 
			
		||||
		Where("id in (?)", taskIDs).
 | 
			
		||||
		Updates(params).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
 | 
			
		||||
	if len(ids) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return DB.Model(&Task{}).
 | 
			
		||||
		Where("id in (?)", ids).
 | 
			
		||||
		Updates(params).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskQuotaUsage struct {
 | 
			
		||||
	Mode  string  `json:"mode"`
 | 
			
		||||
	Count float64 `json:"count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
 | 
			
		||||
	query := DB.Model(Task{})
 | 
			
		||||
	// 添加过滤条件
 | 
			
		||||
	if queryParams.ChannelID != "" {
 | 
			
		||||
		query = query.Where("channel_id = ?", queryParams.ChannelID)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.UserID != "" {
 | 
			
		||||
		query = query.Where("user_id = ?", queryParams.UserID)
 | 
			
		||||
	}
 | 
			
		||||
	if len(queryParams.UserIDs) != 0 {
 | 
			
		||||
		query = query.Where("user_id in (?)", queryParams.UserIDs)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.TaskID != "" {
 | 
			
		||||
		query = query.Where("task_id = ?", queryParams.TaskID)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Action != "" {
 | 
			
		||||
		query = query.Where("action = ?", queryParams.Action)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.Status != "" {
 | 
			
		||||
		query = query.Where("status = ?", queryParams.Status)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.StartTimestamp != 0 {
 | 
			
		||||
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	if queryParams.EndTimestamp != 0 {
 | 
			
		||||
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
 | 
			
		||||
	return stat, err
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,8 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -75,3 +77,13 @@ func batchUpdate() {
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("batch update finished")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordExist(err error) (bool, error) {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
	if errors.Is(err, gorm.ErrRecordNotFound) {
 | 
			
		||||
		return false, nil
 | 
			
		||||
	}
 | 
			
		||||
	return false, err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,3 +19,22 @@ type Adaptor interface {
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
	GetChannelName() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskAdaptor interface {
 | 
			
		||||
	Init(info *relaycommon.TaskRelayInfo)
 | 
			
		||||
 | 
			
		||||
	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
 | 
			
		||||
 | 
			
		||||
	BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
 | 
			
		||||
	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
 | 
			
		||||
	BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
 | 
			
		||||
 | 
			
		||||
	DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
 | 
			
		||||
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
	GetChannelName() string
 | 
			
		||||
 | 
			
		||||
	// FetchTask
 | 
			
		||||
	FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 | 
			
		||||
	_ = c.Request.Body.Close()
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	fullRequestURL, err := a.BuildRequestURL(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("new request failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	req.GetBody = func() (io.ReadCloser, error) {
 | 
			
		||||
		return io.NopCloser(requestBody), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = a.BuildRequestHeader(c, req, info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("setup request header failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := doRequest(c, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("do request failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										172
									
								
								relay/channel/task/suno/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								relay/channel/task/suno/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,172 @@
 | 
			
		||||
package suno
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TaskAdaptor struct {
 | 
			
		||||
	ChannelType int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
 | 
			
		||||
	a.ChannelType = info.ChannelType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
 | 
			
		||||
	action := strings.ToUpper(c.Param("action"))
 | 
			
		||||
 | 
			
		||||
	var sunoRequest *dto.SunoSubmitReq
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &sunoRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err = actionValidate(c, sunoRequest, action)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sunoRequest.ContinueClipId != "" {
 | 
			
		||||
		if sunoRequest.TaskID == "" {
 | 
			
		||||
			taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		info.OriginTaskID = sunoRequest.TaskID
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	info.Action = action
 | 
			
		||||
	c.Set("task_request", sunoRequest)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
 | 
			
		||||
	baseURL := info.BaseUrl
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
 | 
			
		||||
	return fullRequestURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
 | 
			
		||||
	sunoRequest, ok := c.Get("task_request")
 | 
			
		||||
	if !ok {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &sunoRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	data, err := json.Marshal(sunoRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return bytes.NewReader(data), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoTaskApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var sunoResponse dto.TaskResponse[string]
 | 
			
		||||
	err = json.Unmarshal(responseBody, &sunoResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !sunoResponse.IsSuccess() {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sunoResponse.Data, nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) GetChannelName() string {
 | 
			
		||||
	return ChannelName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
 | 
			
		||||
	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
 | 
			
		||||
	byteBody, err := json.Marshal(body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("Get Task error: %v", err))
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer req.Body.Close()
 | 
			
		||||
	// 设置超时时间
 | 
			
		||||
	timeout := time.Second * 15
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	// 使用带有超时的 context 创建新的请求
 | 
			
		||||
	req = req.WithContext(ctx)
 | 
			
		||||
	req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+key)
 | 
			
		||||
	resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
 | 
			
		||||
	switch action {
 | 
			
		||||
	case constant.SunoActionMusic:
 | 
			
		||||
		if sunoRequest.Mv == "" {
 | 
			
		||||
			sunoRequest.Mv = "chirp-v3-0"
 | 
			
		||||
		}
 | 
			
		||||
	case constant.SunoActionLyrics:
 | 
			
		||||
		if sunoRequest.Prompt == "" {
 | 
			
		||||
			err = fmt.Errorf("prompt_empty")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		err = fmt.Errorf("invalid_action")
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										7
									
								
								relay/channel/task/suno/models.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/task/suno/models.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
			
		||||
package suno
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"suno_music", "suno_lyrics",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ChannelName = "suno"
 | 
			
		||||
@@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
 | 
			
		||||
func (info *RelayInfo) SetIsStream(isStream bool) {
 | 
			
		||||
	info.IsStream = isStream
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskRelayInfo struct {
 | 
			
		||||
	ChannelType       int
 | 
			
		||||
	ChannelId         int
 | 
			
		||||
	TokenId           int
 | 
			
		||||
	UserId            int
 | 
			
		||||
	Group             string
 | 
			
		||||
	StartTime         time.Time
 | 
			
		||||
	ApiType           int
 | 
			
		||||
	RelayMode         int
 | 
			
		||||
	UpstreamModelName string
 | 
			
		||||
	RequestURLPath    string
 | 
			
		||||
	ApiKey            string
 | 
			
		||||
	BaseUrl           string
 | 
			
		||||
 | 
			
		||||
	Action       string
 | 
			
		||||
	OriginTaskID string
 | 
			
		||||
 | 
			
		||||
	ConsumeQuota bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	startTime := time.Now()
 | 
			
		||||
 | 
			
		||||
	apiType, _ := constant.ChannelType2APIType(channelType)
 | 
			
		||||
 | 
			
		||||
	info := &TaskRelayInfo{
 | 
			
		||||
		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
 | 
			
		||||
		BaseUrl:        c.GetString("base_url"),
 | 
			
		||||
		RequestURLPath: c.Request.URL.String(),
 | 
			
		||||
		ChannelType:    channelType,
 | 
			
		||||
		ChannelId:      channelId,
 | 
			
		||||
		TokenId:        tokenId,
 | 
			
		||||
		UserId:         userId,
 | 
			
		||||
		Group:          group,
 | 
			
		||||
		StartTime:      startTime,
 | 
			
		||||
		ApiType:        apiType,
 | 
			
		||||
		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
 | 
			
		||||
	}
 | 
			
		||||
	if info.BaseUrl == "" {
 | 
			
		||||
		info.BaseUrl = common.ChannelBaseURLs[channelType]
 | 
			
		||||
	}
 | 
			
		||||
	return info
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,9 @@
 | 
			
		||||
package constant
 | 
			
		||||
 | 
			
		||||
import "strings"
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RelayModeUnknown = iota
 | 
			
		||||
@@ -26,6 +29,9 @@ const (
 | 
			
		||||
	RelayModeMidjourneyModal
 | 
			
		||||
	RelayModeMidjourneyShorten
 | 
			
		||||
	RelayModeSwapFace
 | 
			
		||||
	RelayModeSunoFetch
 | 
			
		||||
	RelayModeSunoFetchByID
 | 
			
		||||
	RelayModeSunoSubmit
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Path2RelayMode(path string) int {
 | 
			
		||||
@@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
 | 
			
		||||
	}
 | 
			
		||||
	return relayMode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Path2RelaySuno(method, path string) int {
 | 
			
		||||
	relayMode := RelayModeUnknown
 | 
			
		||||
	if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
 | 
			
		||||
		relayMode = RelayModeSunoFetch
 | 
			
		||||
	} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
 | 
			
		||||
		relayMode = RelayModeSunoFetchByID
 | 
			
		||||
	} else if strings.Contains(path, "/submit/") {
 | 
			
		||||
		relayMode = RelayModeSunoSubmit
 | 
			
		||||
	}
 | 
			
		||||
	return relayMode
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	commonconstant "one-api/constant"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel/ali"
 | 
			
		||||
	"one-api/relay/channel/aws"
 | 
			
		||||
@@ -12,6 +13,7 @@ import (
 | 
			
		||||
	"one-api/relay/channel/openai"
 | 
			
		||||
	"one-api/relay/channel/palm"
 | 
			
		||||
	"one-api/relay/channel/perplexity"
 | 
			
		||||
	"one-api/relay/channel/task/suno"
 | 
			
		||||
	"one-api/relay/channel/tencent"
 | 
			
		||||
	"one-api/relay/channel/xunfei"
 | 
			
		||||
	"one-api/relay/channel/zhipu"
 | 
			
		||||
@@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
 | 
			
		||||
	switch platform {
 | 
			
		||||
	//case constant.APITypeAIProxyLibrary:
 | 
			
		||||
	//	return &aiproxy.Adaptor{}
 | 
			
		||||
	case commonconstant.TaskPlatformSuno:
 | 
			
		||||
		return &suno.TaskAdaptor{}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										242
									
								
								relay/relay_task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								relay/relay_task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,242 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
Task 任务通过平台、Action 区分任务
 | 
			
		||||
*/
 | 
			
		||||
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 | 
			
		||||
	platform := constant.TaskPlatform(c.GetString("platform"))
 | 
			
		||||
	relayInfo := relaycommon.GenTaskRelayInfo(c)
 | 
			
		||||
 | 
			
		||||
	adaptor := GetTaskAdaptor(platform)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	adaptor.Init(relayInfo)
 | 
			
		||||
	// get & validate taskRequest 获取并验证文本请求
 | 
			
		||||
	taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
 | 
			
		||||
	if taskErr != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
 | 
			
		||||
	modelPrice, success := common.GetModelPrice(modelName, true)
 | 
			
		||||
	if !success {
 | 
			
		||||
		defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			modelPrice = 0.1
 | 
			
		||||
		} else {
 | 
			
		||||
			modelPrice = defaultPrice
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 预扣
 | 
			
		||||
	groupRatio := common.GetGroupRatio(relayInfo.Group)
 | 
			
		||||
	ratio := modelPrice * groupRatio
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	quota := int(ratio * common.QuotaPerUnit)
 | 
			
		||||
	if userQuota-quota < 0 {
 | 
			
		||||
		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relayInfo.OriginTaskID != "" {
 | 
			
		||||
		originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !exist {
 | 
			
		||||
			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if originTask.ChannelId != relayInfo.ChannelId {
 | 
			
		||||
			channel, err := model.GetChannelById(originTask.ChannelId, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
 | 
			
		||||
			}
 | 
			
		||||
			c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
			c.Set("channel_id", originTask.ChannelId)
 | 
			
		||||
			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
 | 
			
		||||
			relayInfo.BaseUrl = channel.GetBaseURL()
 | 
			
		||||
			relayInfo.ChannelId = originTask.ChannelId
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// build body
 | 
			
		||||
	requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// do request
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// handle response
 | 
			
		||||
	if resp != nil && resp.StatusCode != http.StatusOK {
 | 
			
		||||
		responseBody, _ := io.ReadAll(resp.Body)
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		// release quota
 | 
			
		||||
		if relayInfo.ConsumeQuota && taskErr == nil {
 | 
			
		||||
			err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(relayInfo.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
 | 
			
		||||
				other := make(map[string]interface{})
 | 
			
		||||
				other["model_price"] = modelPrice
 | 
			
		||||
				other["group_ratio"] = groupRatio
 | 
			
		||||
				model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 | 
			
		||||
				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	if taskErr != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	relayInfo.ConsumeQuota = true
 | 
			
		||||
	// insert task
 | 
			
		||||
	task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
 | 
			
		||||
	task.TaskID = taskID
 | 
			
		||||
	task.Quota = quota
 | 
			
		||||
	task.Data = taskData
 | 
			
		||||
	err = task.Insert()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
 | 
			
		||||
	relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
 | 
			
		||||
	relayconstant.RelayModeSunoFetch:     sunoFetchRespBodyBuilder,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
 | 
			
		||||
	respBuilder, ok := fetchRespBuilders[relayMode]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBody, taskErr := respBuilder(c)
 | 
			
		||||
	if taskErr != nil {
 | 
			
		||||
		return taskErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	var condition = struct {
 | 
			
		||||
		IDs    []any  `json:"ids"`
 | 
			
		||||
		Action string `json:"action"`
 | 
			
		||||
	}{}
 | 
			
		||||
	err := c.BindJSON(&condition)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var tasks []any
 | 
			
		||||
	if len(condition.IDs) > 0 {
 | 
			
		||||
		taskModels, err := model.GetByTaskIds(userId, condition.IDs)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		for _, task := range taskModels {
 | 
			
		||||
			tasks = append(tasks, TaskModel2Dto(task))
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		tasks = make([]any, 0)
 | 
			
		||||
	}
 | 
			
		||||
	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
 | 
			
		||||
		Code: "success",
 | 
			
		||||
		Data: tasks,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
 | 
			
		||||
	taskId := c.Param("id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
 | 
			
		||||
	originTask, exist, err := model.GetByTaskId(userId, taskId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !exist {
 | 
			
		||||
		taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBody, err = json.Marshal(dto.TaskResponse[any]{
 | 
			
		||||
		Code: "success",
 | 
			
		||||
		Data: TaskModel2Dto(originTask),
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
 | 
			
		||||
	return &dto.TaskDto{
 | 
			
		||||
		TaskID:     task.TaskID,
 | 
			
		||||
		Action:     task.Action,
 | 
			
		||||
		Status:     string(task.Status),
 | 
			
		||||
		FailReason: task.FailReason,
 | 
			
		||||
		SubmitTime: task.SubmitTime,
 | 
			
		||||
		StartTime:  task.StartTime,
 | 
			
		||||
		FinishTime: task.FinishTime,
 | 
			
		||||
		Progress:   task.Progress,
 | 
			
		||||
		Data:       task.Data,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		mjRoute := apiRouter.Group("/mj")
 | 
			
		||||
		mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
 | 
			
		||||
		mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
 | 
			
		||||
 | 
			
		||||
		taskRoute := apiRouter.Group("/task")
 | 
			
		||||
		{
 | 
			
		||||
			taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
 | 
			
		||||
			taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	relayMjModeRouter := router.Group("/:mode/mj")
 | 
			
		||||
	registerMjRouterGroup(relayMjModeRouter)
 | 
			
		||||
	//relayMjRouter.Use()
 | 
			
		||||
 | 
			
		||||
	relaySunoRouter := router.Group("/suno")
 | 
			
		||||
	relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
			
		||||
	{
 | 
			
		||||
		relaySunoRouter.POST("/submit/:action", controller.RelayTask)
 | 
			
		||||
		relaySunoRouter.POST("/fetch", controller.RelayTask)
 | 
			
		||||
		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
 | 
			
		||||
 
 | 
			
		||||
@@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
 | 
			
		||||
		openaiErr.StatusCode = intCode
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
 | 
			
		||||
	openaiErr := TaskErrorWrapper(err, code, statusCode)
 | 
			
		||||
	openaiErr.LocalError = true
 | 
			
		||||
	return openaiErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
 | 
			
		||||
	text := err.Error()
 | 
			
		||||
 | 
			
		||||
	// 定义一个正则表达式匹配URL
 | 
			
		||||
	if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
 | 
			
		||||
		common.SysLog(fmt.Sprintf("error: %s", text))
 | 
			
		||||
		text = "请求上游地址失败"
 | 
			
		||||
	}
 | 
			
		||||
	//避免暴露内部错误
 | 
			
		||||
 | 
			
		||||
	taskError := &dto.TaskError{
 | 
			
		||||
		Code:       code,
 | 
			
		||||
		Message:    text,
 | 
			
		||||
		StatusCode: statusCode,
 | 
			
		||||
		Error:      err,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return taskError
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								service/task.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								service/task.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
 | 
			
		||||
	return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
 | 
			
		||||
}
 | 
			
		||||
@@ -23,6 +23,7 @@ import Chat from './pages/Chat';
 | 
			
		||||
import { Layout } from '@douyinfe/semi-ui';
 | 
			
		||||
import Midjourney from './pages/Midjourney';
 | 
			
		||||
import Pricing from './pages/Pricing/index.js';
 | 
			
		||||
import Task from "./pages/Task/index.js";
 | 
			
		||||
// import Detail from './pages/Detail';
 | 
			
		||||
 | 
			
		||||
const Home = lazy(() => import('./pages/Home'));
 | 
			
		||||
@@ -220,6 +221,16 @@ function App() {
 | 
			
		||||
              </PrivateRoute>
 | 
			
		||||
            }
 | 
			
		||||
          />
 | 
			
		||||
          <Route
 | 
			
		||||
            path='/task'
 | 
			
		||||
            element={
 | 
			
		||||
                <PrivateRoute>
 | 
			
		||||
                    <Suspense fallback={<Loading></Loading>}>
 | 
			
		||||
                        <Task />
 | 
			
		||||
                    </Suspense>
 | 
			
		||||
                </PrivateRoute>
 | 
			
		||||
            }
 | 
			
		||||
          />
 | 
			
		||||
          <Route
 | 
			
		||||
            path='/pricing'
 | 
			
		||||
            element={
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ import {
 | 
			
		||||
import '../index.css';
 | 
			
		||||
 | 
			
		||||
import {
 | 
			
		||||
  IconCalendarClock,
 | 
			
		||||
  IconCalendarClock, IconChecklistStroked,
 | 
			
		||||
  IconComment,
 | 
			
		||||
  IconCreditCard,
 | 
			
		||||
  IconGift,
 | 
			
		||||
@@ -58,6 +58,7 @@ const SiderBar = () => {
 | 
			
		||||
    chat: '/chat',
 | 
			
		||||
    detail: '/detail',
 | 
			
		||||
    pricing: '/pricing',
 | 
			
		||||
    task: '/task',
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const headerButtons = useMemo(
 | 
			
		||||
@@ -142,6 +143,16 @@ const SiderBar = () => {
 | 
			
		||||
            ? 'semi-navigation-item-normal'
 | 
			
		||||
            : 'tableHiddle',
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        text: '异步任务',
 | 
			
		||||
        itemKey: 'task',
 | 
			
		||||
        to: '/task',
 | 
			
		||||
        icon: <IconChecklistStroked />,
 | 
			
		||||
        className:
 | 
			
		||||
            localStorage.getItem('enable_task') === 'true'
 | 
			
		||||
                ? 'semi-navigation-item-normal'
 | 
			
		||||
                : 'tableHiddle',
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        text: '设置',
 | 
			
		||||
        itemKey: 'setting',
 | 
			
		||||
@@ -158,6 +169,7 @@ const SiderBar = () => {
 | 
			
		||||
    [
 | 
			
		||||
      localStorage.getItem('enable_data_export'),
 | 
			
		||||
      localStorage.getItem('enable_drawing'),
 | 
			
		||||
      localStorage.getItem('enable_task'),
 | 
			
		||||
      localStorage.getItem('chat_link'),
 | 
			
		||||
      isAdmin(),
 | 
			
		||||
    ],
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										400
									
								
								web/src/components/TaskLogsTable.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										400
									
								
								web/src/components/TaskLogsTable.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,400 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Label } from 'semantic-ui-react';
 | 
			
		||||
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import {
 | 
			
		||||
    Table,
 | 
			
		||||
    Tag,
 | 
			
		||||
    Form,
 | 
			
		||||
    Button,
 | 
			
		||||
    Layout,
 | 
			
		||||
    Modal,
 | 
			
		||||
    Typography, Progress, Card
 | 
			
		||||
} from '@douyinfe/semi-ui';
 | 
			
		||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
 | 
			
		||||
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
 | 
			
		||||
    'light-blue', 'lime', 'orange', 'pink',
 | 
			
		||||
    'purple', 'red', 'teal', 'violet', 'yellow'
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
const renderTimestamp = (timestampInSeconds) => {
 | 
			
		||||
    const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
 | 
			
		||||
 | 
			
		||||
    const year = date.getFullYear(); // 获取年份
 | 
			
		||||
    const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
 | 
			
		||||
    const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
 | 
			
		||||
    const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
 | 
			
		||||
    const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
 | 
			
		||||
    const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
 | 
			
		||||
 | 
			
		||||
    return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
function renderDuration(submit_time, finishTime) {
 | 
			
		||||
    // 确保startTime和finishTime都是有效的时间戳
 | 
			
		||||
    if (!submit_time || !finishTime) return 'N/A';
 | 
			
		||||
 | 
			
		||||
    // 将时间戳转换为Date对象
 | 
			
		||||
    const start = new Date(submit_time);
 | 
			
		||||
    const finish = new Date(finishTime);
 | 
			
		||||
 | 
			
		||||
    // 计算时间差(毫秒)
 | 
			
		||||
    const durationMs = finish - start;
 | 
			
		||||
 | 
			
		||||
    // 将时间差转换为秒,并保留一位小数
 | 
			
		||||
    const durationSec = (durationMs / 1000).toFixed(1);
 | 
			
		||||
 | 
			
		||||
    // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
 | 
			
		||||
    const color = durationSec > 60 ? 'red' : 'green';
 | 
			
		||||
 | 
			
		||||
    // 返回带有样式的颜色标签
 | 
			
		||||
    return (
 | 
			
		||||
        <Tag color={color} size="large">
 | 
			
		||||
            {durationSec} 秒
 | 
			
		||||
        </Tag>
 | 
			
		||||
    );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const LogsTable = () => {
 | 
			
		||||
    const [isModalOpen, setIsModalOpen] = useState(false);
 | 
			
		||||
    const [modalContent, setModalContent] = useState('');
 | 
			
		||||
    const isAdminUser = isAdmin();
 | 
			
		||||
    const columns = [
 | 
			
		||||
        {
 | 
			
		||||
            title: "提交时间",
 | 
			
		||||
            dataIndex: 'submit_time',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {text ? renderTimestamp(text) : "-"}
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: "结束时间",
 | 
			
		||||
            dataIndex: 'finish_time',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {text ? renderTimestamp(text) : "-"}
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: '进度',
 | 
			
		||||
            dataIndex: 'progress',
 | 
			
		||||
            width: 50,
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {
 | 
			
		||||
                            // 转换例如100%为数字100,如果text未定义,返回0
 | 
			
		||||
                            isNaN(text.replace('%', '')) ? text : <Progress width={42} type="circle" showInfo={true} percent={Number(text.replace('%', '') || 0)} aria-label="drawing progress" />
 | 
			
		||||
                        }
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: '花费时间',
 | 
			
		||||
            dataIndex: 'finish_time', // 以finish_time作为dataIndex
 | 
			
		||||
            key: 'finish_time',
 | 
			
		||||
            render: (finish, record) => {
 | 
			
		||||
                // 假设record.start_time是存在的,并且finish是完成时间的时间戳
 | 
			
		||||
                return <>
 | 
			
		||||
                    {
 | 
			
		||||
                        finish ? renderDuration(record.submit_time, finish) : "-"
 | 
			
		||||
                    }
 | 
			
		||||
                </>
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: "渠道",
 | 
			
		||||
            dataIndex: 'channel_id',
 | 
			
		||||
            className: isAdminUser ? 'tableShow' : 'tableHiddle',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        <Tag
 | 
			
		||||
                            color={colors[parseInt(text) % colors.length]}
 | 
			
		||||
                            size='large'
 | 
			
		||||
                            onClick={() => {
 | 
			
		||||
                                copyText(text); // 假设copyText是用于文本复制的函数
 | 
			
		||||
                            }}
 | 
			
		||||
                        >
 | 
			
		||||
                            {' '}
 | 
			
		||||
                            {text}{' '}
 | 
			
		||||
                        </Tag>
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: "平台",
 | 
			
		||||
            dataIndex: 'platform',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {renderPlatform(text)}
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: '类型',
 | 
			
		||||
            dataIndex: 'action',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {renderType(text)}
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: '任务ID(点击查看详情)',
 | 
			
		||||
            dataIndex: 'task_id',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (<Typography.Text
 | 
			
		||||
                    ellipsis={{ showTooltip: true }}
 | 
			
		||||
                    //style={{width: 100}}
 | 
			
		||||
                    onClick={() => {
 | 
			
		||||
                        setModalContent(JSON.stringify(record, null, 2));
 | 
			
		||||
                        setIsModalOpen(true);
 | 
			
		||||
                    }}
 | 
			
		||||
                >
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {text}
 | 
			
		||||
                    </div>
 | 
			
		||||
                </Typography.Text>);
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            title: '任务状态',
 | 
			
		||||
            dataIndex: 'status',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                return (
 | 
			
		||||
                    <div>
 | 
			
		||||
                        {renderStatus(text)}
 | 
			
		||||
                    </div>
 | 
			
		||||
                );
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            title: '失败原因',
 | 
			
		||||
            dataIndex: 'fail_reason',
 | 
			
		||||
            render: (text, record, index) => {
 | 
			
		||||
                // 如果text未定义,返回替代文本,例如空字符串''或其他
 | 
			
		||||
                if (!text) {
 | 
			
		||||
                    return '无';
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                return (
 | 
			
		||||
                    <Typography.Text
 | 
			
		||||
                        ellipsis={{ showTooltip: true }}
 | 
			
		||||
                        style={{ width: 100 }}
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                            setModalContent(text);
 | 
			
		||||
                            setIsModalOpen(true);
 | 
			
		||||
                        }}
 | 
			
		||||
                    >
 | 
			
		||||
                        {text}
 | 
			
		||||
                    </Typography.Text>
 | 
			
		||||
                );
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    ];
 | 
			
		||||
 | 
			
		||||
    const [logs, setLogs] = useState([]);
 | 
			
		||||
    const [loading, setLoading] = useState(true);
 | 
			
		||||
    const [activePage, setActivePage] = useState(1);
 | 
			
		||||
    const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
 | 
			
		||||
    const [logType] = useState(0);
 | 
			
		||||
 | 
			
		||||
    let now = new Date();
 | 
			
		||||
    // 初始化start_timestamp为前一天
 | 
			
		||||
    let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate());
 | 
			
		||||
    const [inputs, setInputs] = useState({
 | 
			
		||||
        channel_id: '',
 | 
			
		||||
        task_id: '',
 | 
			
		||||
        start_timestamp: timestamp2string(zeroNow.getTime() /1000),
 | 
			
		||||
        end_timestamp: '',
 | 
			
		||||
    });
 | 
			
		||||
    const { channel_id, task_id, start_timestamp, end_timestamp } = inputs;
 | 
			
		||||
 | 
			
		||||
    const handleInputChange = (value, name) => {
 | 
			
		||||
        setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    const setLogsFormat = (logs) => {
 | 
			
		||||
        for (let i = 0; i < logs.length; i++) {
 | 
			
		||||
            logs[i].timestamp2string = timestamp2string(logs[i].created_at);
 | 
			
		||||
            logs[i].key = '' + logs[i].id;
 | 
			
		||||
        }
 | 
			
		||||
        // data.key = '' + data.id
 | 
			
		||||
        setLogs(logs);
 | 
			
		||||
        setLogCount(logs.length + ITEMS_PER_PAGE);
 | 
			
		||||
        // console.log(logCount);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const loadLogs = async (startIdx) => {
 | 
			
		||||
        setLoading(true);
 | 
			
		||||
 | 
			
		||||
        let url = '';
 | 
			
		||||
        let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000);
 | 
			
		||||
        let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 );
 | 
			
		||||
        if (isAdminUser) {
 | 
			
		||||
            url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
        } else {
 | 
			
		||||
            url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
        }
 | 
			
		||||
        const res = await API.get(url);
 | 
			
		||||
        let { success, message, data } = res.data;
 | 
			
		||||
        if (success) {
 | 
			
		||||
            if (startIdx === 0) {
 | 
			
		||||
                setLogsFormat(data);
 | 
			
		||||
            } else {
 | 
			
		||||
                let newLogs = [...logs];
 | 
			
		||||
                newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
 | 
			
		||||
                setLogsFormat(newLogs);
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            showError(message);
 | 
			
		||||
        }
 | 
			
		||||
        setLoading(false);
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
 | 
			
		||||
 | 
			
		||||
    const handlePageChange = page => {
 | 
			
		||||
        setActivePage(page);
 | 
			
		||||
        if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
 | 
			
		||||
            // In this case we have to load more data and then append them.
 | 
			
		||||
            loadLogs(page - 1).then(r => {
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    const refresh = async () => {
 | 
			
		||||
        // setLoading(true);
 | 
			
		||||
        setActivePage(1);
 | 
			
		||||
        await loadLogs(0);
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    const copyText = async (text) => {
 | 
			
		||||
        if (await copy(text)) {
 | 
			
		||||
            showSuccess('已复制:' + text);
 | 
			
		||||
        } else {
 | 
			
		||||
            // setSearchKeyword(text);
 | 
			
		||||
            Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text });
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    useEffect(() => {
 | 
			
		||||
        refresh().then();
 | 
			
		||||
    }, [logType]);
 | 
			
		||||
 | 
			
		||||
    const renderType = (type) => {
 | 
			
		||||
        switch (type) {
 | 
			
		||||
            case 'MUSIC':
 | 
			
		||||
                return <Label basic color='grey'> 生成音乐 </Label>;
 | 
			
		||||
            case 'LYRICS':
 | 
			
		||||
                return <Label basic color='pink'> 生成歌词 </Label>;
 | 
			
		||||
 | 
			
		||||
            default:
 | 
			
		||||
                return <Label basic color='black'> 未知 </Label>;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const renderPlatform = (type) => {
 | 
			
		||||
        switch (type) {
 | 
			
		||||
            case "suno":
 | 
			
		||||
                return <Label basic color='green'> Suno </Label>;
 | 
			
		||||
            default:
 | 
			
		||||
                return <Label basic color='black'> 未知 </Label>;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const renderStatus = (type) => {
 | 
			
		||||
        switch (type) {
 | 
			
		||||
            case 'SUCCESS':
 | 
			
		||||
                return <Label basic color='green'> 成功 </Label>;
 | 
			
		||||
            case 'NOT_START':
 | 
			
		||||
                return <Label basic color='black'> 未启动 </Label>;
 | 
			
		||||
            case 'SUBMITTED':
 | 
			
		||||
                return <Label basic color='yellow'> 队列中 </Label>;
 | 
			
		||||
            case 'IN_PROGRESS':
 | 
			
		||||
                return <Label basic color='blue'> 执行中 </Label>;
 | 
			
		||||
            case 'FAILURE':
 | 
			
		||||
                return <Label basic color='red'> 失败 </Label>;
 | 
			
		||||
            case 'QUEUED':
 | 
			
		||||
                return <Label basic color='red'> 排队中 </Label>;
 | 
			
		||||
            case 'UNKNOWN':
 | 
			
		||||
                return <Label basic color='red'> 未知 </Label>;
 | 
			
		||||
            case '':
 | 
			
		||||
                return <Label basic color='black'> 正在提交 </Label>;
 | 
			
		||||
            default:
 | 
			
		||||
                return <Label basic color='black'> 未知 </Label>;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return (
 | 
			
		||||
        <>
 | 
			
		||||
 | 
			
		||||
            <Layout>
 | 
			
		||||
                <Form layout='horizontal' labelPosition='inset'>
 | 
			
		||||
                    <>
 | 
			
		||||
                        {isAdminUser && <Form.Input field="channel_id" label='渠道 ID' style={{ width: '236px', marginBottom: '10px' }} value={channel_id}
 | 
			
		||||
                                                    placeholder={'可选值'} name='channel_id'
 | 
			
		||||
                                                    onChange={value => handleInputChange(value, 'channel_id')} />
 | 
			
		||||
                        }
 | 
			
		||||
                        <Form.Input field="task_id" label={"任务 ID"} style={{ width: '236px', marginBottom: '10px' }} value={task_id}
 | 
			
		||||
                            placeholder={"可选值"}
 | 
			
		||||
                            name='task_id'
 | 
			
		||||
                            onChange={value => handleInputChange(value, 'task_id')} />
 | 
			
		||||
 | 
			
		||||
                        <Form.DatePicker field="start_timestamp" label={"起始时间"} style={{ width: '236px', marginBottom: '10px' }}
 | 
			
		||||
                            initValue={start_timestamp}
 | 
			
		||||
                            value={start_timestamp} type='dateTime'
 | 
			
		||||
                            name='start_timestamp'
 | 
			
		||||
                            onChange={value => handleInputChange(value, 'start_timestamp')} />
 | 
			
		||||
                        <Form.DatePicker field="end_timestamp" fluid label={"结束时间"} style={{ width: '236px', marginBottom: '10px' }}
 | 
			
		||||
                            initValue={end_timestamp}
 | 
			
		||||
                            value={end_timestamp} type='dateTime'
 | 
			
		||||
                            name='end_timestamp'
 | 
			
		||||
                            onChange={value => handleInputChange(value, 'end_timestamp')} />
 | 
			
		||||
                        <Button label={"查询"} type="primary" htmlType="submit" className="btn-margin-right"
 | 
			
		||||
                            onClick={refresh}>查询</Button>
 | 
			
		||||
                    </>
 | 
			
		||||
                </Form>
 | 
			
		||||
                <Card>
 | 
			
		||||
                    <Table columns={columns} dataSource={pageData} pagination={{
 | 
			
		||||
                        currentPage: activePage,
 | 
			
		||||
                        pageSize: ITEMS_PER_PAGE,
 | 
			
		||||
                        total: logCount,
 | 
			
		||||
                        pageSizeOpts: [10, 20, 50, 100],
 | 
			
		||||
                        onPageChange: handlePageChange,
 | 
			
		||||
                    }} loading={loading} />
 | 
			
		||||
                </Card>
 | 
			
		||||
                <Modal
 | 
			
		||||
                    visible={isModalOpen}
 | 
			
		||||
                    onOk={() => setIsModalOpen(false)}
 | 
			
		||||
                    onCancel={() => setIsModalOpen(false)}
 | 
			
		||||
                    closable={null}
 | 
			
		||||
                    bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
 | 
			
		||||
                    width={800} // 设置模态框宽度
 | 
			
		||||
                >
 | 
			
		||||
                    <p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
 | 
			
		||||
                </Modal>
 | 
			
		||||
            </Layout>
 | 
			
		||||
        </>
 | 
			
		||||
    );
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export default LogsTable;
 | 
			
		||||
@@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
    color: 'blue',
 | 
			
		||||
    label: 'Midjourney Proxy Plus',
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    key: 36,
 | 
			
		||||
    text: 'Suno API',
 | 
			
		||||
    value: 36,
 | 
			
		||||
    color: 'purple',
 | 
			
		||||
    label: 'Suno API',
 | 
			
		||||
  },
 | 
			
		||||
  { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
 | 
			
		||||
  {
 | 
			
		||||
    key: 14,
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ export function setStatusData(data) {
 | 
			
		||||
  localStorage.setItem('quota_per_unit', data.quota_per_unit);
 | 
			
		||||
  localStorage.setItem('display_in_currency', data.display_in_currency);
 | 
			
		||||
  localStorage.setItem('enable_drawing', data.enable_drawing);
 | 
			
		||||
  localStorage.setItem('enable_task', data.enable_task);
 | 
			
		||||
  localStorage.setItem('enable_data_export', data.enable_data_export);
 | 
			
		||||
  localStorage.setItem(
 | 
			
		||||
    'data_export_default_time',
 | 
			
		||||
 
 | 
			
		||||
@@ -126,6 +126,12 @@ const EditChannel = (props) => {
 | 
			
		||||
            'mj_uploads',
 | 
			
		||||
          ];
 | 
			
		||||
          break;
 | 
			
		||||
        case 36:
 | 
			
		||||
          localModels = [
 | 
			
		||||
            'suno_music',
 | 
			
		||||
            'suno_lyrics',
 | 
			
		||||
          ];
 | 
			
		||||
          break;
 | 
			
		||||
        default:
 | 
			
		||||
          localModels = getChannelModels(value);
 | 
			
		||||
          break;
 | 
			
		||||
@@ -513,12 +519,32 @@ const EditChannel = (props) => {
 | 
			
		||||
              />
 | 
			
		||||
            </>
 | 
			
		||||
          )}
 | 
			
		||||
          <div style={{ marginTop: 10 }}>
 | 
			
		||||
          {inputs.type === 36 && (
 | 
			
		||||
              <>
 | 
			
		||||
                <div style={{marginTop: 10}}>
 | 
			
		||||
                  <Typography.Text strong>
 | 
			
		||||
                    注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用
 | 
			
		||||
                  </Typography.Text>
 | 
			
		||||
                </div>
 | 
			
		||||
                <Input
 | 
			
		||||
                    name='base_url'
 | 
			
		||||
                    placeholder={
 | 
			
		||||
                      '请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com '
 | 
			
		||||
                    }
 | 
			
		||||
                    onChange={(value) => {
 | 
			
		||||
                      handleInputChange('base_url', value);
 | 
			
		||||
                    }}
 | 
			
		||||
                    value={inputs.base_url}
 | 
			
		||||
                    autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </>
 | 
			
		||||
          )}
 | 
			
		||||
          <div style={{marginTop: 10}}>
 | 
			
		||||
            <Typography.Text strong>名称:</Typography.Text>
 | 
			
		||||
          </div>
 | 
			
		||||
          <Input
 | 
			
		||||
            required
 | 
			
		||||
            name='name'
 | 
			
		||||
              required
 | 
			
		||||
              name='name'
 | 
			
		||||
            placeholder={'请为渠道命名'}
 | 
			
		||||
            onChange={(value) => {
 | 
			
		||||
              handleInputChange('name', value);
 | 
			
		||||
@@ -758,7 +784,7 @@ const EditChannel = (props) => {
 | 
			
		||||
              </Space>
 | 
			
		||||
            </div>
 | 
			
		||||
          )}
 | 
			
		||||
          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
 | 
			
		||||
          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
 | 
			
		||||
            <>
 | 
			
		||||
              <div style={{ marginTop: 10 }}>
 | 
			
		||||
                <Typography.Text strong>代理:</Typography.Text>
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								web/src/pages/Task/index.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								web/src/pages/Task/index.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
import React from 'react';
 | 
			
		||||
import TaskLogsTable from "../../components/TaskLogsTable.js";
 | 
			
		||||
 | 
			
		||||
const Task = () => (
 | 
			
		||||
  <>
 | 
			
		||||
    <TaskLogsTable />
 | 
			
		||||
  </>
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
export default Task;
 | 
			
		||||
		Reference in New Issue
	
	Block a user