mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-18 00:16:37 +08:00
commit
23bfc4f655
@ -47,6 +47,11 @@
|
|||||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||||
|
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下:
|
||||||
|
+ [x] /suno/submit/music
|
||||||
|
+ [x] /suno/submit/lyrics
|
||||||
|
+ [x] /suno/fetch
|
||||||
|
+ [x] /suno/fetch/:id
|
||||||
|
|
||||||
## 模型支持
|
## 模型支持
|
||||||
此版本额外支持以下模型:
|
此版本额外支持以下模型:
|
||||||
@ -57,6 +62,7 @@
|
|||||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||||
6. [零一万物](https://platform.lingyiwanwu.com/)
|
6. [零一万物](https://platform.lingyiwanwu.com/)
|
||||||
7. 自定义渠道,支持填入完整调用地址
|
7. 自定义渠道,支持填入完整调用地址
|
||||||
|
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
||||||
|
|
||||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||||
|
|
||||||
@ -105,6 +111,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
|||||||
## Midjourney接口设置文档
|
## Midjourney接口设置文档
|
||||||
[对接文档](Midjourney.md)
|
[对接文档](Midjourney.md)
|
||||||
|
|
||||||
|
## Suno接口设置文档
|
||||||
|
[对接文档](Suno.md)
|
||||||
|
|
||||||
## 交流群
|
## 交流群
|
||||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
||||||
|
|
||||||
|
37
Suno.md
Normal file
37
Suno.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# Suno API文档
|
||||||
|
|
||||||
|
**简介**:Suno API文档
|
||||||
|
|
||||||
|
## 模型列表
|
||||||
|
|
||||||
|
### Suno API支持
|
||||||
|
|
||||||
|
- suno_music (自定义模式、灵感模式、续写)
|
||||||
|
- suno_lyrics (生成歌词)
|
||||||
|
|
||||||
|
|
||||||
|
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"suno_music": 0.3,
|
||||||
|
"suno_lyrics": 0.01
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 渠道设置
|
||||||
|
|
||||||
|
### 对接 Suno API
|
||||||
|
|
||||||
|
1.
|
||||||
|
部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
|
||||||
|
|
||||||
|
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
|
||||||
|
,模型请参考上方模型列表
|
||||||
|
3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
|
||||||
|
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
|
||||||
|
|
||||||
|
### 对接上游new api
|
||||||
|
|
||||||
|
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
|
||||||
|
2. **代理**填写上游new api的地址,例如:http://localhost:3000
|
||||||
|
3. 密钥填写上游new api的密钥
|
@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
|||||||
var DisplayInCurrencyEnabled = true
|
var DisplayInCurrencyEnabled = true
|
||||||
var DisplayTokenStatEnabled = true
|
var DisplayTokenStatEnabled = true
|
||||||
var DrawingEnabled = true
|
var DrawingEnabled = true
|
||||||
|
var TaskEnabled = true
|
||||||
var DataExportEnabled = true
|
var DataExportEnabled = true
|
||||||
var DataExportInterval = 5 // unit: minute
|
var DataExportInterval = 5 // unit: minute
|
||||||
var DataExportDefaultTime = "hour" // unit: minute
|
var DataExportDefaultTime = "hour" // unit: minute
|
||||||
@ -208,8 +209,10 @@ const (
|
|||||||
ChannelTypeAws = 33
|
ChannelTypeAws = 33
|
||||||
ChannelTypeCohere = 34
|
ChannelTypeCohere = 34
|
||||||
ChannelTypeMiniMax = 35
|
ChannelTypeMiniMax = 35
|
||||||
|
ChannelTypeSunoAPI = 36
|
||||||
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@ -249,4 +252,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //33
|
"", //33
|
||||||
"https://api.cohere.ai", //34
|
"https://api.cohere.ai", //34
|
||||||
"https://api.minimax.chat", //35
|
"https://api.minimax.chat", //35
|
||||||
|
"", //36
|
||||||
}
|
}
|
||||||
|
18
constant/task.go
Normal file
18
constant/task.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type TaskPlatform string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
|
TaskPlatformMidjourney = "mj"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SunoActionMusic = "MUSIC"
|
||||||
|
SunoActionLyrics = "LYRICS"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SunoModel2Action = map[string]string{
|
||||||
|
"suno_music": SunoActionMusic,
|
||||||
|
"suno_lyrics": SunoActionLyrics,
|
||||||
|
}
|
@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
if channel.Type == common.ChannelTypeMidjourney {
|
if channel.Type == common.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return errors.New("midjourney channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
|
if channel.Type == common.ChannelTypeSunoAPI {
|
||||||
|
return errors.New("suno channel test is not supported"), nil
|
||||||
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = &http.Request{
|
c.Request = &http.Request{
|
||||||
|
@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) {
|
|||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
"enable_batch_update": common.BatchUpdateEnabled,
|
"enable_batch_update": common.BatchUpdateEnabled,
|
||||||
"enable_drawing": common.DrawingEnabled,
|
"enable_drawing": common.DrawingEnabled,
|
||||||
|
"enable_task": common.TaskEnabled,
|
||||||
"enable_data_export": common.DataExportEnabled,
|
"enable_data_export": common.DataExportEnabled,
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
|
@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
"error": err,
|
"error": err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RelayTask(c *gin.Context) {
|
||||||
|
retryTimes := common.RetryTimes
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
relayMode := c.GetInt("relay_mode")
|
||||||
|
group := c.GetString("group")
|
||||||
|
originalModel := c.GetString("original_model")
|
||||||
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||||
|
taskErr := taskRelayHandler(c, relayMode)
|
||||||
|
if taskErr == nil {
|
||||||
|
retryTimes = 0
|
||||||
|
}
|
||||||
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
channelId = channel.Id
|
||||||
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
|
c.Set("use_channel", useChannel)
|
||||||
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
|
requestBody, err := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
taskErr = taskRelayHandler(c, relayMode)
|
||||||
|
}
|
||||||
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
|
if len(useChannel) > 1 {
|
||||||
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
|
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||||
|
}
|
||||||
|
if taskErr != nil {
|
||||||
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
|
}
|
||||||
|
c.JSON(taskErr.StatusCode, taskErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||||
|
var err *dto.TaskError
|
||||||
|
switch relayMode {
|
||||||
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||||
|
err = relay.RelayTaskFetch(c, relayMode)
|
||||||
|
default:
|
||||||
|
err = relay.RelayTaskSubmit(c, relayMode)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
||||||
|
if taskErr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if retryTimes <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := c.Get("specific_channel_id"); ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode == 307 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode/100 == 5 {
|
||||||
|
// 超时不重试
|
||||||
|
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode == http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode == 408 {
|
||||||
|
// azure处理超时不重试
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if taskErr.LocalError {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if taskErr.StatusCode/100 == 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
284
controller/task.go
Normal file
284
controller/task.go
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/relay"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdateTaskBulk() {
|
||||||
|
//revocer
|
||||||
|
//imageModel := "midjourney"
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(15) * time.Second)
|
||||||
|
common.SysLog("任务进度轮询开始")
|
||||||
|
ctx := context.TODO()
|
||||||
|
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||||
|
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||||
|
for _, t := range allTasks {
|
||||||
|
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||||
|
}
|
||||||
|
for platform, tasks := range platformTask {
|
||||||
|
if len(tasks) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskChannelM := make(map[int][]string)
|
||||||
|
taskM := make(map[string]*model.Task)
|
||||||
|
nullTaskIds := make([]int64, 0)
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task.TaskID == "" {
|
||||||
|
// 统计失败的未完成任务
|
||||||
|
nullTaskIds = append(nullTaskIds, task.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskM[task.TaskID] = task
|
||||||
|
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
|
||||||
|
}
|
||||||
|
if len(nullTaskIds) > 0 {
|
||||||
|
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||||
|
} else {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(taskChannelM) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||||
|
}
|
||||||
|
common.SysLog("任务进度轮询完成")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
||||||
|
switch platform {
|
||||||
|
case constant.TaskPlatformMidjourney:
|
||||||
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
|
case constant.TaskPlatformSuno:
|
||||||
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
|
default:
|
||||||
|
common.SysLog("未知平台")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
channel, err := model.CacheGetChannel(channelId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||||
|
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
|
||||||
|
if adaptor == nil {
|
||||||
|
return errors.New("adaptor not found")
|
||||||
|
}
|
||||||
|
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||||
|
"ids": taskIds,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
|
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||||
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !responseItems.IsSuccess() {
|
||||||
|
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, responseItem := range responseItems.Data {
|
||||||
|
task := taskM[responseItem.TaskID]
|
||||||
|
if !checkTaskNeedUpdate(task, responseItem) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
||||||
|
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
||||||
|
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
||||||
|
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||||
|
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||||
|
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||||
|
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||||
|
task.Progress = "100%"
|
||||||
|
err = model.CacheUpdateUserQuota(task.UserId)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
|
} else {
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if responseItem.Status == model.TaskStatusSuccess {
|
||||||
|
task.Progress = "100%"
|
||||||
|
}
|
||||||
|
task.Data = responseItem.Data
|
||||||
|
|
||||||
|
err = task.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
||||||
|
|
||||||
|
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.StartTime != newTask.StartTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FinishTime != newTask.FinishTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if string(oldTask.Status) != newTask.Status {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FailReason != newTask.FailReason {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if oldTask.FinishTime != newTask.FinishTime {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
oldData, _ := json.Marshal(oldTask.Data)
|
||||||
|
newData, _ := json.Marshal(newTask.Data)
|
||||||
|
|
||||||
|
sort.Slice(oldData, func(i, j int) bool {
|
||||||
|
return oldData[i] < oldData[j]
|
||||||
|
})
|
||||||
|
sort.Slice(newData, func(i, j int) bool {
|
||||||
|
return newData[i] < newData[j]
|
||||||
|
})
|
||||||
|
|
||||||
|
if string(oldData) != string(newData) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllTask(c *gin.Context) {
|
||||||
|
p, _ := strconv.Atoi(c.Query("p"))
|
||||||
|
if p < 0 {
|
||||||
|
p = 0
|
||||||
|
}
|
||||||
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
|
// 解析其他查询参数
|
||||||
|
queryParams := model.SyncTaskQueryParams{
|
||||||
|
Platform: constant.TaskPlatform(c.Query("platform")),
|
||||||
|
TaskID: c.Query("task_id"),
|
||||||
|
Status: c.Query("status"),
|
||||||
|
Action: c.Query("action"),
|
||||||
|
StartTimestamp: startTimestamp,
|
||||||
|
EndTimestamp: endTimestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||||
|
if logs == nil {
|
||||||
|
logs = make([]*model.Task, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": logs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserTask(c *gin.Context) {
|
||||||
|
p, _ := strconv.Atoi(c.Query("p"))
|
||||||
|
if p < 0 {
|
||||||
|
p = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
|
|
||||||
|
queryParams := model.SyncTaskQueryParams{
|
||||||
|
Platform: constant.TaskPlatform(c.Query("platform")),
|
||||||
|
TaskID: c.Query("task_id"),
|
||||||
|
Status: c.Query("status"),
|
||||||
|
Action: c.Query("action"),
|
||||||
|
StartTimestamp: startTimestamp,
|
||||||
|
EndTimestamp: endTimestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||||
|
if logs == nil {
|
||||||
|
logs = make([]*model.Task, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": logs,
|
||||||
|
})
|
||||||
|
}
|
129
dto/suno.go
Normal file
129
dto/suno.go
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskData interface {
|
||||||
|
SunoDataResponse | []SunoDataResponse | string | any
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoSubmitReq struct {
|
||||||
|
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Mv string `json:"mv,omitempty"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
Tags string `json:"tags,omitempty"`
|
||||||
|
ContinueAt float64 `json:"continue_at,omitempty"`
|
||||||
|
TaskID string `json:"task_id,omitempty"`
|
||||||
|
ContinueClipId string `json:"continue_clip_id,omitempty"`
|
||||||
|
MakeInstrumental bool `json:"make_instrumental"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FetchReq struct {
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoDataResponse struct {
|
||||||
|
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
|
||||||
|
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
||||||
|
Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||||
|
StartTime int64 `json:"start_time" gorm:"index"`
|
||||||
|
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||||
|
Data json.RawMessage `json:"data" gorm:"type:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoSong struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
VideoURL string `json:"video_url"`
|
||||||
|
AudioURL string `json:"audio_url"`
|
||||||
|
ImageURL string `json:"image_url"`
|
||||||
|
ImageLargeURL string `json:"image_large_url"`
|
||||||
|
MajorModelVersion string `json:"major_model_version"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Metadata SunoMetadata `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoMetadata struct {
|
||||||
|
Tags string `json:"tags"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
|
||||||
|
AudioPromptID interface{} `json:"audio_prompt_id"`
|
||||||
|
Duration interface{} `json:"duration"`
|
||||||
|
ErrorType interface{} `json:"error_type"`
|
||||||
|
ErrorMessage interface{} `json:"error_message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoLyrics struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const TaskSuccessCode = "success"
|
||||||
|
|
||||||
|
type TaskResponse[T TaskData] struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data T `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TaskResponse[T]) IsSuccess() bool {
|
||||||
|
return t.Code == TaskSuccessCode
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskDto struct {
|
||||||
|
TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
|
||||||
|
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
|
||||||
|
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
SubmitTime int64 `json:"submit_time"`
|
||||||
|
StartTime int64 `json:"start_time"`
|
||||||
|
FinishTime int64 `json:"finish_time"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoGoAPISubmitReq struct {
|
||||||
|
CustomMode bool `json:"custom_mode"`
|
||||||
|
|
||||||
|
Input SunoGoAPISubmitReqInput `json:"input"`
|
||||||
|
|
||||||
|
NotifyHook string `json:"notify_hook,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SunoGoAPISubmitReqInput struct {
|
||||||
|
GptDescriptionPrompt string `json:"gpt_description_prompt"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Mv string `json:"mv"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Tags string `json:"tags"`
|
||||||
|
ContinueAt float64 `json:"continue_at"`
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
ContinueClipId string `json:"continue_clip_id"`
|
||||||
|
MakeInstrumental bool `json:"make_instrumental"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GoAPITaskResponse[T any] struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data T `json:"data"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GoAPITaskResponseData struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GoAPIFetchResponseData struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Input string `json:"input"`
|
||||||
|
Clips map[string]SunoSong `json:"clips"`
|
||||||
|
}
|
10
dto/task.go
Normal file
10
dto/task.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type TaskError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data any `json:"data"`
|
||||||
|
StatusCode int `json:"-"`
|
||||||
|
LocalError bool `json:"-"`
|
||||||
|
Error error `json:"-"`
|
||||||
|
}
|
3
main.go
3
main.go
@ -92,6 +92,9 @@ func main() {
|
|||||||
common.SafeGoroutine(func() {
|
common.SafeGoroutine(func() {
|
||||||
controller.UpdateMidjourneyTaskBulk()
|
controller.UpdateMidjourneyTaskBulk()
|
||||||
})
|
})
|
||||||
|
common.SafeGoroutine(func() {
|
||||||
|
controller.UpdateTaskBulk()
|
||||||
|
})
|
||||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = midjourneyModel
|
modelRequest.Model = midjourneyModel
|
||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
|
} else if strings.Contains(c.Request.URL.Path, "/suno/") {
|
||||||
|
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
|
||||||
|
if relayMode == relayconstant.RelayModeSunoFetch ||
|
||||||
|
relayMode == relayconstant.RelayModeSunoFetchByID {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
} else {
|
||||||
|
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
|
||||||
|
modelRequest.Model = modelName
|
||||||
|
}
|
||||||
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
|
@ -140,6 +140,10 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = db.AutoMigrate(&Task{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
common.SysLog("database migrated")
|
common.SysLog("database migrated")
|
||||||
err = createRootAccountIfNeed()
|
err = createRootAccountIfNeed()
|
||||||
return err
|
return err
|
||||||
|
@ -41,6 +41,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||||
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
||||||
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
|
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
|
||||||
|
common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
|
||||||
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
||||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||||
@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.DisplayTokenStatEnabled = boolValue
|
common.DisplayTokenStatEnabled = boolValue
|
||||||
case "DrawingEnabled":
|
case "DrawingEnabled":
|
||||||
common.DrawingEnabled = boolValue
|
common.DrawingEnabled = boolValue
|
||||||
|
case "TaskEnabled":
|
||||||
|
common.TaskEnabled = boolValue
|
||||||
case "DataExportEnabled":
|
case "DataExportEnabled":
|
||||||
common.DataExportEnabled = boolValue
|
common.DataExportEnabled = boolValue
|
||||||
case "DefaultCollapseSidebar":
|
case "DefaultCollapseSidebar":
|
||||||
|
304
model/task.go
Normal file
304
model/task.go
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/constant"
|
||||||
|
commonRelay "one-api/relay/common"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskStatusNotStart TaskStatus = "NOT_START"
|
||||||
|
TaskStatusSubmitted = "SUBMITTED"
|
||||||
|
TaskStatusQueued = "QUEUED"
|
||||||
|
TaskStatusInProgress = "IN_PROGRESS"
|
||||||
|
TaskStatusFailure = "FAILURE"
|
||||||
|
TaskStatusSuccess = "SUCCESS"
|
||||||
|
TaskStatusUnknown = "UNKNOWN"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
|
||||||
|
CreatedAt int64 `json:"created_at" gorm:"index"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id
|
||||||
|
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
|
||||||
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
|
ChannelId int `json:"channel_id" gorm:"index"`
|
||||||
|
Quota int `json:"quota"`
|
||||||
|
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
||||||
|
Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||||
|
StartTime int64 `json:"start_time" gorm:"index"`
|
||||||
|
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||||
|
Progress string `json:"progress" gorm:"type:varchar(20);index"`
|
||||||
|
Properties Properties `json:"properties" gorm:"type:json"`
|
||||||
|
|
||||||
|
Data json.RawMessage `json:"data" gorm:"type:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) SetData(data any) {
|
||||||
|
b, _ := json.Marshal(data)
|
||||||
|
t.Data = json.RawMessage(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) GetData(v any) error {
|
||||||
|
err := json.Unmarshal(t.Data, &v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Properties struct {
|
||||||
|
Input string `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Properties) Scan(val interface{}) error {
|
||||||
|
bytesValue, _ := val.([]byte)
|
||||||
|
return json.Unmarshal(bytesValue, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Properties) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||||
|
type SyncTaskQueryParams struct {
|
||||||
|
Platform constant.TaskPlatform
|
||||||
|
ChannelID string
|
||||||
|
TaskID string
|
||||||
|
UserID string
|
||||||
|
Action string
|
||||||
|
Status string
|
||||||
|
StartTimestamp int64
|
||||||
|
EndTimestamp int64
|
||||||
|
UserIDs []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
|
||||||
|
t := &Task{
|
||||||
|
UserId: relayInfo.UserId,
|
||||||
|
SubmitTime: time.Now().Unix(),
|
||||||
|
Status: TaskStatusNotStart,
|
||||||
|
Progress: "0%",
|
||||||
|
ChannelId: relayInfo.ChannelId,
|
||||||
|
Platform: platform,
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
|
||||||
|
var tasks []*Task
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 初始化查询构建器
|
||||||
|
query := DB.Where("user_id = ?", userId)
|
||||||
|
|
||||||
|
if queryParams.TaskID != "" {
|
||||||
|
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||||
|
}
|
||||||
|
if queryParams.Action != "" {
|
||||||
|
query = query.Where("action = ?", queryParams.Action)
|
||||||
|
}
|
||||||
|
if queryParams.Status != "" {
|
||||||
|
query = query.Where("status = ?", queryParams.Status)
|
||||||
|
}
|
||||||
|
if queryParams.Platform != "" {
|
||||||
|
query = query.Where("platform = ?", queryParams.Platform)
|
||||||
|
}
|
||||||
|
if queryParams.StartTimestamp != 0 {
|
||||||
|
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
|
||||||
|
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||||
|
}
|
||||||
|
if queryParams.EndTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取数据
|
||||||
|
err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
|
||||||
|
var tasks []*Task
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 初始化查询构建器
|
||||||
|
query := DB
|
||||||
|
|
||||||
|
// 添加过滤条件
|
||||||
|
if queryParams.ChannelID != "" {
|
||||||
|
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||||
|
}
|
||||||
|
if queryParams.Platform != "" {
|
||||||
|
query = query.Where("platform = ?", queryParams.Platform)
|
||||||
|
}
|
||||||
|
if queryParams.UserID != "" {
|
||||||
|
query = query.Where("user_id = ?", queryParams.UserID)
|
||||||
|
}
|
||||||
|
if len(queryParams.UserIDs) != 0 {
|
||||||
|
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||||
|
}
|
||||||
|
if queryParams.TaskID != "" {
|
||||||
|
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||||
|
}
|
||||||
|
if queryParams.Action != "" {
|
||||||
|
query = query.Where("action = ?", queryParams.Action)
|
||||||
|
}
|
||||||
|
if queryParams.Status != "" {
|
||||||
|
query = query.Where("status = ?", queryParams.Status)
|
||||||
|
}
|
||||||
|
if queryParams.StartTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||||
|
}
|
||||||
|
if queryParams.EndTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取数据
|
||||||
|
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllUnFinishSyncTasks(limit int) []*Task {
|
||||||
|
var tasks []*Task
|
||||||
|
var err error
|
||||||
|
// get all tasks progress is not 100%
|
||||||
|
err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
|
||||||
|
if taskId == "" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
var task *Task
|
||||||
|
var err error
|
||||||
|
err = DB.Where("task_id = ?", taskId).First(&task).Error
|
||||||
|
exist, err := RecordExist(err)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
return task, exist, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
|
||||||
|
if taskId == "" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
var task *Task
|
||||||
|
var err error
|
||||||
|
err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
|
||||||
|
First(&task).Error
|
||||||
|
exist, err := RecordExist(err)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
return task, exist, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var task []*Task
|
||||||
|
var err error
|
||||||
|
err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
|
||||||
|
Find(&task).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskUpdateProgress(id int64, progress string) error {
|
||||||
|
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Task *Task) Insert() error {
|
||||||
|
var err error
|
||||||
|
err = DB.Create(Task).Error
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Task *Task) Update() error {
|
||||||
|
var err error
|
||||||
|
err = DB.Save(Task).Error
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
|
||||||
|
if len(TaskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return DB.Model(&Task{}).
|
||||||
|
Where("task_id in (?)", TaskIds).
|
||||||
|
Updates(params).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
|
||||||
|
if len(taskIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return DB.Model(&Task{}).
|
||||||
|
Where("id in (?)", taskIDs).
|
||||||
|
Updates(params).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return DB.Model(&Task{}).
|
||||||
|
Where("id in (?)", ids).
|
||||||
|
Updates(params).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskQuotaUsage struct {
|
||||||
|
Mode string `json:"mode"`
|
||||||
|
Count float64 `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
|
||||||
|
query := DB.Model(Task{})
|
||||||
|
// 添加过滤条件
|
||||||
|
if queryParams.ChannelID != "" {
|
||||||
|
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||||
|
}
|
||||||
|
if queryParams.UserID != "" {
|
||||||
|
query = query.Where("user_id = ?", queryParams.UserID)
|
||||||
|
}
|
||||||
|
if len(queryParams.UserIDs) != 0 {
|
||||||
|
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||||
|
}
|
||||||
|
if queryParams.TaskID != "" {
|
||||||
|
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||||
|
}
|
||||||
|
if queryParams.Action != "" {
|
||||||
|
query = query.Where("action = ?", queryParams.Action)
|
||||||
|
}
|
||||||
|
if queryParams.Status != "" {
|
||||||
|
query = query.Where("status = ?", queryParams.Status)
|
||||||
|
}
|
||||||
|
if queryParams.StartTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||||
|
}
|
||||||
|
if queryParams.EndTimestamp != 0 {
|
||||||
|
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||||
|
}
|
||||||
|
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
||||||
|
return stat, err
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -75,3 +77,13 @@ func batchUpdate() {
|
|||||||
}
|
}
|
||||||
common.SysLog("batch update finished")
|
common.SysLog("batch update finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RecordExist(err error) (bool, error) {
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
@ -19,3 +19,22 @@ type Adaptor interface {
|
|||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TaskAdaptor interface {
|
||||||
|
Init(info *relaycommon.TaskRelayInfo)
|
||||||
|
|
||||||
|
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
|
||||||
|
|
||||||
|
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
|
||||||
|
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
|
||||||
|
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
|
||||||
|
|
||||||
|
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||||
|
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||||
|
|
||||||
|
GetModelList() []string
|
||||||
|
GetChannelName() string
|
||||||
|
|
||||||
|
// FetchTask
|
||||||
|
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||||
|
}
|
||||||
|
@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
|||||||
_ = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
fullRequestURL, err := a.BuildRequestURL(info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(requestBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.BuildRequestHeader(c, req, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
|
}
|
||||||
|
resp, err := doRequest(c, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("do request failed: %w", err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
172
relay/channel/task/suno/adaptor.go
Normal file
172
relay/channel/task/suno/adaptor.go
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package suno
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskAdaptor struct {
|
||||||
|
ChannelType int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
action := strings.ToUpper(c.Param("action"))
|
||||||
|
|
||||||
|
var sunoRequest *dto.SunoSubmitReq
|
||||||
|
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = actionValidate(c, sunoRequest, action)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if sunoRequest.ContinueClipId != "" {
|
||||||
|
if sunoRequest.TaskID == "" {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info.OriginTaskID = sunoRequest.TaskID
|
||||||
|
}
|
||||||
|
|
||||||
|
info.Action = action
|
||||||
|
c.Set("task_request", sunoRequest)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
|
baseURL := info.BaseUrl
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
||||||
|
return fullRequestURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||||
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||||
|
sunoRequest, ok := c.Get("task_request")
|
||||||
|
if !ok {
|
||||||
|
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(sunoRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var sunoResponse dto.TaskResponse[string]
|
||||||
|
err = json.Unmarshal(responseBody, &sunoResponse)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !sunoResponse.IsSuccess() {
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return sunoResponse.Data, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||||
|
byteBody, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("Get Task error: %v", err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer req.Body.Close()
|
||||||
|
// 设置超时时间
|
||||||
|
timeout := time.Second * 15
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
// 使用带有超时的 context 创建新的请求
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
|
||||||
|
switch action {
|
||||||
|
case constant.SunoActionMusic:
|
||||||
|
if sunoRequest.Mv == "" {
|
||||||
|
sunoRequest.Mv = "chirp-v3-0"
|
||||||
|
}
|
||||||
|
case constant.SunoActionLyrics:
|
||||||
|
if sunoRequest.Prompt == "" {
|
||||||
|
err = fmt.Errorf("prompt_empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("invalid_action")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
7
relay/channel/task/suno/models.go
Normal file
7
relay/channel/task/suno/models.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package suno
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"suno_music", "suno_lyrics",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelName = "suno"
|
@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
|||||||
func (info *RelayInfo) SetIsStream(isStream bool) {
|
func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||||
info.IsStream = isStream
|
info.IsStream = isStream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TaskRelayInfo struct {
|
||||||
|
ChannelType int
|
||||||
|
ChannelId int
|
||||||
|
TokenId int
|
||||||
|
UserId int
|
||||||
|
Group string
|
||||||
|
StartTime time.Time
|
||||||
|
ApiType int
|
||||||
|
RelayMode int
|
||||||
|
UpstreamModelName string
|
||||||
|
RequestURLPath string
|
||||||
|
ApiKey string
|
||||||
|
BaseUrl string
|
||||||
|
|
||||||
|
Action string
|
||||||
|
OriginTaskID string
|
||||||
|
|
||||||
|
ConsumeQuota bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
group := c.GetString("group")
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||||
|
|
||||||
|
info := &TaskRelayInfo{
|
||||||
|
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||||
|
BaseUrl: c.GetString("base_url"),
|
||||||
|
RequestURLPath: c.Request.URL.String(),
|
||||||
|
ChannelType: channelType,
|
||||||
|
ChannelId: channelId,
|
||||||
|
TokenId: tokenId,
|
||||||
|
UserId: userId,
|
||||||
|
Group: group,
|
||||||
|
StartTime: startTime,
|
||||||
|
ApiType: apiType,
|
||||||
|
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||||
|
}
|
||||||
|
if info.BaseUrl == "" {
|
||||||
|
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RelayModeUnknown = iota
|
RelayModeUnknown = iota
|
||||||
@ -26,6 +29,9 @@ const (
|
|||||||
RelayModeMidjourneyModal
|
RelayModeMidjourneyModal
|
||||||
RelayModeMidjourneyShorten
|
RelayModeMidjourneyShorten
|
||||||
RelayModeSwapFace
|
RelayModeSwapFace
|
||||||
|
RelayModeSunoFetch
|
||||||
|
RelayModeSunoFetchByID
|
||||||
|
RelayModeSunoSubmit
|
||||||
)
|
)
|
||||||
|
|
||||||
func Path2RelayMode(path string) int {
|
func Path2RelayMode(path string) int {
|
||||||
@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
|
|||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Path2RelaySuno(method, path string) int {
|
||||||
|
relayMode := RelayModeUnknown
|
||||||
|
if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
|
||||||
|
relayMode = RelayModeSunoFetch
|
||||||
|
} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
|
||||||
|
relayMode = RelayModeSunoFetchByID
|
||||||
|
} else if strings.Contains(path, "/submit/") {
|
||||||
|
relayMode = RelayModeSunoSubmit
|
||||||
|
}
|
||||||
|
return relayMode
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
commonconstant "one-api/constant"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ali"
|
"one-api/relay/channel/ali"
|
||||||
"one-api/relay/channel/aws"
|
"one-api/relay/channel/aws"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
"one-api/relay/channel/palm"
|
"one-api/relay/channel/palm"
|
||||||
"one-api/relay/channel/perplexity"
|
"one-api/relay/channel/perplexity"
|
||||||
|
"one-api/relay/channel/task/suno"
|
||||||
"one-api/relay/channel/tencent"
|
"one-api/relay/channel/tencent"
|
||||||
"one-api/relay/channel/xunfei"
|
"one-api/relay/channel/xunfei"
|
||||||
"one-api/relay/channel/zhipu"
|
"one-api/relay/channel/zhipu"
|
||||||
@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||||
|
switch platform {
|
||||||
|
//case constant.APITypeAIProxyLibrary:
|
||||||
|
// return &aiproxy.Adaptor{}
|
||||||
|
case commonconstant.TaskPlatformSuno:
|
||||||
|
return &suno.TaskAdaptor{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
242
relay/relay_task.go
Normal file
242
relay/relay_task.go
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
Task 任务通过平台、Action 区分任务
|
||||||
|
*/
|
||||||
|
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||||
|
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||||
|
relayInfo := relaycommon.GenTaskRelayInfo(c)
|
||||||
|
|
||||||
|
adaptor := GetTaskAdaptor(platform)
|
||||||
|
if adaptor == nil {
|
||||||
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
adaptor.Init(relayInfo)
|
||||||
|
// get & validate taskRequest 获取并验证文本请求
|
||||||
|
taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
|
||||||
|
if taskErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||||
|
modelPrice, success := common.GetModelPrice(modelName, true)
|
||||||
|
if !success {
|
||||||
|
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
|
||||||
|
if !ok {
|
||||||
|
modelPrice = 0.1
|
||||||
|
} else {
|
||||||
|
modelPrice = defaultPrice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预扣
|
||||||
|
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||||
|
ratio := modelPrice * groupRatio
|
||||||
|
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
quota := int(ratio * common.QuotaPerUnit)
|
||||||
|
if userQuota-quota < 0 {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayInfo.OriginTaskID != "" {
|
||||||
|
originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exist {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if originTask.ChannelId != relayInfo.ChannelId {
|
||||||
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
|
||||||
|
relayInfo.BaseUrl = channel.GetBaseURL()
|
||||||
|
relayInfo.ChannelId = originTask.ChannelId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// build body
|
||||||
|
requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// do request
|
||||||
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// handle response
|
||||||
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||||
|
responseBody, _ := io.ReadAll(resp.Body)
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
// release quota
|
||||||
|
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||||
|
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["model_price"] = modelPrice
|
||||||
|
other["group_ratio"] = groupRatio
|
||||||
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||||
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
|
||||||
|
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||||
|
if taskErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
relayInfo.ConsumeQuota = true
|
||||||
|
// insert task
|
||||||
|
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
||||||
|
task.TaskID = taskID
|
||||||
|
task.Quota = quota
|
||||||
|
task.Data = taskData
|
||||||
|
err = task.Insert()
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||||
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||||
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||||
|
respBuilder, ok := fetchRespBuilders[relayMode]
|
||||||
|
if !ok {
|
||||||
|
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, taskErr := respBuilder(c)
|
||||||
|
if taskErr != nil {
|
||||||
|
return taskErr
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
var condition = struct {
|
||||||
|
IDs []any `json:"ids"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
}{}
|
||||||
|
err := c.BindJSON(&condition)
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var tasks []any
|
||||||
|
if len(condition.IDs) > 0 {
|
||||||
|
taskModels, err := model.GetByTaskIds(userId, condition.IDs)
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, task := range taskModels {
|
||||||
|
tasks = append(tasks, TaskModel2Dto(task))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tasks = make([]any, 0)
|
||||||
|
}
|
||||||
|
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: tasks,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exist {
|
||||||
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: TaskModel2Dto(originTask),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||||
|
return &dto.TaskDto{
|
||||||
|
TaskID: task.TaskID,
|
||||||
|
Action: task.Action,
|
||||||
|
Status: string(task.Status),
|
||||||
|
FailReason: task.FailReason,
|
||||||
|
SubmitTime: task.SubmitTime,
|
||||||
|
StartTime: task.StartTime,
|
||||||
|
FinishTime: task.FinishTime,
|
||||||
|
Progress: task.Progress,
|
||||||
|
Data: task.Data,
|
||||||
|
}
|
||||||
|
}
|
@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
mjRoute := apiRouter.Group("/mj")
|
mjRoute := apiRouter.Group("/mj")
|
||||||
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
|
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
|
||||||
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
|
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
|
||||||
|
|
||||||
|
taskRoute := apiRouter.Group("/task")
|
||||||
|
{
|
||||||
|
taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
|
||||||
|
taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayMjModeRouter := router.Group("/:mode/mj")
|
relayMjModeRouter := router.Group("/:mode/mj")
|
||||||
registerMjRouterGroup(relayMjModeRouter)
|
registerMjRouterGroup(relayMjModeRouter)
|
||||||
//relayMjRouter.Use()
|
//relayMjRouter.Use()
|
||||||
|
|
||||||
|
relaySunoRouter := router.Group("/suno")
|
||||||
|
relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||||
|
{
|
||||||
|
relaySunoRouter.POST("/submit/:action", controller.RelayTask)
|
||||||
|
relaySunoRouter.POST("/fetch", controller.RelayTask)
|
||||||
|
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||||
|
@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
|
|||||||
openaiErr.StatusCode = intCode
|
openaiErr.StatusCode = intCode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
|
||||||
|
openaiErr := TaskErrorWrapper(err, code, statusCode)
|
||||||
|
openaiErr.LocalError = true
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
|
||||||
|
text := err.Error()
|
||||||
|
|
||||||
|
// 定义一个正则表达式匹配URL
|
||||||
|
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
|
||||||
|
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||||
|
text = "请求上游地址失败"
|
||||||
|
}
|
||||||
|
//避免暴露内部错误
|
||||||
|
|
||||||
|
taskError := &dto.TaskError{
|
||||||
|
Code: code,
|
||||||
|
Message: text,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
|
||||||
|
return taskError
|
||||||
|
}
|
||||||
|
10
service/task.go
Normal file
10
service/task.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/constant"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
|
||||||
|
return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
|
||||||
|
}
|
@ -23,6 +23,7 @@ import Chat from './pages/Chat';
|
|||||||
import { Layout } from '@douyinfe/semi-ui';
|
import { Layout } from '@douyinfe/semi-ui';
|
||||||
import Midjourney from './pages/Midjourney';
|
import Midjourney from './pages/Midjourney';
|
||||||
import Pricing from './pages/Pricing/index.js';
|
import Pricing from './pages/Pricing/index.js';
|
||||||
|
import Task from "./pages/Task/index.js";
|
||||||
// import Detail from './pages/Detail';
|
// import Detail from './pages/Detail';
|
||||||
|
|
||||||
const Home = lazy(() => import('./pages/Home'));
|
const Home = lazy(() => import('./pages/Home'));
|
||||||
@ -220,6 +221,16 @@ function App() {
|
|||||||
</PrivateRoute>
|
</PrivateRoute>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
<Route
|
||||||
|
path='/task'
|
||||||
|
element={
|
||||||
|
<PrivateRoute>
|
||||||
|
<Suspense fallback={<Loading></Loading>}>
|
||||||
|
<Task />
|
||||||
|
</Suspense>
|
||||||
|
</PrivateRoute>
|
||||||
|
}
|
||||||
|
/>
|
||||||
<Route
|
<Route
|
||||||
path='/pricing'
|
path='/pricing'
|
||||||
element={
|
element={
|
||||||
|
@ -14,7 +14,7 @@ import {
|
|||||||
import '../index.css';
|
import '../index.css';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
IconCalendarClock,
|
IconCalendarClock, IconChecklistStroked,
|
||||||
IconComment,
|
IconComment,
|
||||||
IconCreditCard,
|
IconCreditCard,
|
||||||
IconGift,
|
IconGift,
|
||||||
@ -58,6 +58,7 @@ const SiderBar = () => {
|
|||||||
chat: '/chat',
|
chat: '/chat',
|
||||||
detail: '/detail',
|
detail: '/detail',
|
||||||
pricing: '/pricing',
|
pricing: '/pricing',
|
||||||
|
task: '/task',
|
||||||
};
|
};
|
||||||
|
|
||||||
const headerButtons = useMemo(
|
const headerButtons = useMemo(
|
||||||
@ -142,6 +143,16 @@ const SiderBar = () => {
|
|||||||
? 'semi-navigation-item-normal'
|
? 'semi-navigation-item-normal'
|
||||||
: 'tableHiddle',
|
: 'tableHiddle',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
text: '异步任务',
|
||||||
|
itemKey: 'task',
|
||||||
|
to: '/task',
|
||||||
|
icon: <IconChecklistStroked />,
|
||||||
|
className:
|
||||||
|
localStorage.getItem('enable_task') === 'true'
|
||||||
|
? 'semi-navigation-item-normal'
|
||||||
|
: 'tableHiddle',
|
||||||
|
},
|
||||||
{
|
{
|
||||||
text: '设置',
|
text: '设置',
|
||||||
itemKey: 'setting',
|
itemKey: 'setting',
|
||||||
@ -158,6 +169,7 @@ const SiderBar = () => {
|
|||||||
[
|
[
|
||||||
localStorage.getItem('enable_data_export'),
|
localStorage.getItem('enable_data_export'),
|
||||||
localStorage.getItem('enable_drawing'),
|
localStorage.getItem('enable_drawing'),
|
||||||
|
localStorage.getItem('enable_task'),
|
||||||
localStorage.getItem('chat_link'),
|
localStorage.getItem('chat_link'),
|
||||||
isAdmin(),
|
isAdmin(),
|
||||||
],
|
],
|
||||||
|
400
web/src/components/TaskLogsTable.js
Normal file
400
web/src/components/TaskLogsTable.js
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
import React, { useEffect, useState } from 'react';
|
||||||
|
import { Label } from 'semantic-ui-react';
|
||||||
|
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
|
||||||
|
|
||||||
|
import {
|
||||||
|
Table,
|
||||||
|
Tag,
|
||||||
|
Form,
|
||||||
|
Button,
|
||||||
|
Layout,
|
||||||
|
Modal,
|
||||||
|
Typography, Progress, Card
|
||||||
|
} from '@douyinfe/semi-ui';
|
||||||
|
import { ITEMS_PER_PAGE } from '../constants';
|
||||||
|
|
||||||
|
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
|
||||||
|
'light-blue', 'lime', 'orange', 'pink',
|
||||||
|
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
const renderTimestamp = (timestampInSeconds) => {
|
||||||
|
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
|
||||||
|
|
||||||
|
const year = date.getFullYear(); // 获取年份
|
||||||
|
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
|
||||||
|
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
|
||||||
|
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
|
||||||
|
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
|
||||||
|
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
|
||||||
|
|
||||||
|
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
|
||||||
|
};
|
||||||
|
|
||||||
|
function renderDuration(submit_time, finishTime) {
|
||||||
|
// 确保startTime和finishTime都是有效的时间戳
|
||||||
|
if (!submit_time || !finishTime) return 'N/A';
|
||||||
|
|
||||||
|
// 将时间戳转换为Date对象
|
||||||
|
const start = new Date(submit_time);
|
||||||
|
const finish = new Date(finishTime);
|
||||||
|
|
||||||
|
// 计算时间差(毫秒)
|
||||||
|
const durationMs = finish - start;
|
||||||
|
|
||||||
|
// 将时间差转换为秒,并保留一位小数
|
||||||
|
const durationSec = (durationMs / 1000).toFixed(1);
|
||||||
|
|
||||||
|
// 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
|
||||||
|
const color = durationSec > 60 ? 'red' : 'green';
|
||||||
|
|
||||||
|
// 返回带有样式的颜色标签
|
||||||
|
return (
|
||||||
|
<Tag color={color} size="large">
|
||||||
|
{durationSec} 秒
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const LogsTable = () => {
|
||||||
|
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||||
|
const [modalContent, setModalContent] = useState('');
|
||||||
|
const isAdminUser = isAdmin();
|
||||||
|
const columns = [
|
||||||
|
{
|
||||||
|
title: "提交时间",
|
||||||
|
dataIndex: 'submit_time',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{text ? renderTimestamp(text) : "-"}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "结束时间",
|
||||||
|
dataIndex: 'finish_time',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{text ? renderTimestamp(text) : "-"}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '进度',
|
||||||
|
dataIndex: 'progress',
|
||||||
|
width: 50,
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{
|
||||||
|
// 转换例如100%为数字100,如果text未定义,返回0
|
||||||
|
isNaN(text.replace('%', '')) ? text : <Progress width={42} type="circle" showInfo={true} percent={Number(text.replace('%', '') || 0)} aria-label="drawing progress" />
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '花费时间',
|
||||||
|
dataIndex: 'finish_time', // 以finish_time作为dataIndex
|
||||||
|
key: 'finish_time',
|
||||||
|
render: (finish, record) => {
|
||||||
|
// 假设record.start_time是存在的,并且finish是完成时间的时间戳
|
||||||
|
return <>
|
||||||
|
{
|
||||||
|
finish ? renderDuration(record.submit_time, finish) : "-"
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "渠道",
|
||||||
|
dataIndex: 'channel_id',
|
||||||
|
className: isAdminUser ? 'tableShow' : 'tableHiddle',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Tag
|
||||||
|
color={colors[parseInt(text) % colors.length]}
|
||||||
|
size='large'
|
||||||
|
onClick={() => {
|
||||||
|
copyText(text); // 假设copyText是用于文本复制的函数
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{' '}
|
||||||
|
{text}{' '}
|
||||||
|
</Tag>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "平台",
|
||||||
|
dataIndex: 'platform',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{renderPlatform(text)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '类型',
|
||||||
|
dataIndex: 'action',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{renderType(text)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '任务ID(点击查看详情)',
|
||||||
|
dataIndex: 'task_id',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (<Typography.Text
|
||||||
|
ellipsis={{ showTooltip: true }}
|
||||||
|
//style={{width: 100}}
|
||||||
|
onClick={() => {
|
||||||
|
setModalContent(JSON.stringify(record, null, 2));
|
||||||
|
setIsModalOpen(true);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
{text}
|
||||||
|
</div>
|
||||||
|
</Typography.Text>);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '任务状态',
|
||||||
|
dataIndex: 'status',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{renderStatus(text)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
title: '失败原因',
|
||||||
|
dataIndex: 'fail_reason',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||||
|
if (!text) {
|
||||||
|
return '无';
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Typography.Text
|
||||||
|
ellipsis={{ showTooltip: true }}
|
||||||
|
style={{ width: 100 }}
|
||||||
|
onClick={() => {
|
||||||
|
setModalContent(text);
|
||||||
|
setIsModalOpen(true);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{text}
|
||||||
|
</Typography.Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
const [logs, setLogs] = useState([]);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [activePage, setActivePage] = useState(1);
|
||||||
|
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||||
|
const [logType] = useState(0);
|
||||||
|
|
||||||
|
let now = new Date();
|
||||||
|
// 初始化start_timestamp为前一天
|
||||||
|
let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate());
|
||||||
|
const [inputs, setInputs] = useState({
|
||||||
|
channel_id: '',
|
||||||
|
task_id: '',
|
||||||
|
start_timestamp: timestamp2string(zeroNow.getTime() /1000),
|
||||||
|
end_timestamp: '',
|
||||||
|
});
|
||||||
|
const { channel_id, task_id, start_timestamp, end_timestamp } = inputs;
|
||||||
|
|
||||||
|
const handleInputChange = (value, name) => {
|
||||||
|
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
const setLogsFormat = (logs) => {
|
||||||
|
for (let i = 0; i < logs.length; i++) {
|
||||||
|
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||||
|
logs[i].key = '' + logs[i].id;
|
||||||
|
}
|
||||||
|
// data.key = '' + data.id
|
||||||
|
setLogs(logs);
|
||||||
|
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||||
|
// console.log(logCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadLogs = async (startIdx) => {
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
|
let url = '';
|
||||||
|
let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000);
|
||||||
|
let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 );
|
||||||
|
if (isAdminUser) {
|
||||||
|
url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
|
} else {
|
||||||
|
url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
|
}
|
||||||
|
const res = await API.get(url);
|
||||||
|
let { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
if (startIdx === 0) {
|
||||||
|
setLogsFormat(data);
|
||||||
|
} else {
|
||||||
|
let newLogs = [...logs];
|
||||||
|
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
|
||||||
|
setLogsFormat(newLogs);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
setLoading(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
|
||||||
|
|
||||||
|
const handlePageChange = page => {
|
||||||
|
setActivePage(page);
|
||||||
|
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
|
||||||
|
// In this case we have to load more data and then append them.
|
||||||
|
loadLogs(page - 1).then(r => {
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const refresh = async () => {
|
||||||
|
// setLoading(true);
|
||||||
|
setActivePage(1);
|
||||||
|
await loadLogs(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
const copyText = async (text) => {
|
||||||
|
if (await copy(text)) {
|
||||||
|
showSuccess('已复制:' + text);
|
||||||
|
} else {
|
||||||
|
// setSearchKeyword(text);
|
||||||
|
Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
refresh().then();
|
||||||
|
}, [logType]);
|
||||||
|
|
||||||
|
const renderType = (type) => {
|
||||||
|
switch (type) {
|
||||||
|
case 'MUSIC':
|
||||||
|
return <Label basic color='grey'> 生成音乐 </Label>;
|
||||||
|
case 'LYRICS':
|
||||||
|
return <Label basic color='pink'> 生成歌词 </Label>;
|
||||||
|
|
||||||
|
default:
|
||||||
|
return <Label basic color='black'> 未知 </Label>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderPlatform = (type) => {
|
||||||
|
switch (type) {
|
||||||
|
case "suno":
|
||||||
|
return <Label basic color='green'> Suno </Label>;
|
||||||
|
default:
|
||||||
|
return <Label basic color='black'> 未知 </Label>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderStatus = (type) => {
|
||||||
|
switch (type) {
|
||||||
|
case 'SUCCESS':
|
||||||
|
return <Label basic color='green'> 成功 </Label>;
|
||||||
|
case 'NOT_START':
|
||||||
|
return <Label basic color='black'> 未启动 </Label>;
|
||||||
|
case 'SUBMITTED':
|
||||||
|
return <Label basic color='yellow'> 队列中 </Label>;
|
||||||
|
case 'IN_PROGRESS':
|
||||||
|
return <Label basic color='blue'> 执行中 </Label>;
|
||||||
|
case 'FAILURE':
|
||||||
|
return <Label basic color='red'> 失败 </Label>;
|
||||||
|
case 'QUEUED':
|
||||||
|
return <Label basic color='red'> 排队中 </Label>;
|
||||||
|
case 'UNKNOWN':
|
||||||
|
return <Label basic color='red'> 未知 </Label>;
|
||||||
|
case '':
|
||||||
|
return <Label basic color='black'> 正在提交 </Label>;
|
||||||
|
default:
|
||||||
|
return <Label basic color='black'> 未知 </Label>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
|
||||||
|
<Layout>
|
||||||
|
<Form layout='horizontal' labelPosition='inset'>
|
||||||
|
<>
|
||||||
|
{isAdminUser && <Form.Input field="channel_id" label='渠道 ID' style={{ width: '236px', marginBottom: '10px' }} value={channel_id}
|
||||||
|
placeholder={'可选值'} name='channel_id'
|
||||||
|
onChange={value => handleInputChange(value, 'channel_id')} />
|
||||||
|
}
|
||||||
|
<Form.Input field="task_id" label={"任务 ID"} style={{ width: '236px', marginBottom: '10px' }} value={task_id}
|
||||||
|
placeholder={"可选值"}
|
||||||
|
name='task_id'
|
||||||
|
onChange={value => handleInputChange(value, 'task_id')} />
|
||||||
|
|
||||||
|
<Form.DatePicker field="start_timestamp" label={"起始时间"} style={{ width: '236px', marginBottom: '10px' }}
|
||||||
|
initValue={start_timestamp}
|
||||||
|
value={start_timestamp} type='dateTime'
|
||||||
|
name='start_timestamp'
|
||||||
|
onChange={value => handleInputChange(value, 'start_timestamp')} />
|
||||||
|
<Form.DatePicker field="end_timestamp" fluid label={"结束时间"} style={{ width: '236px', marginBottom: '10px' }}
|
||||||
|
initValue={end_timestamp}
|
||||||
|
value={end_timestamp} type='dateTime'
|
||||||
|
name='end_timestamp'
|
||||||
|
onChange={value => handleInputChange(value, 'end_timestamp')} />
|
||||||
|
<Button label={"查询"} type="primary" htmlType="submit" className="btn-margin-right"
|
||||||
|
onClick={refresh}>查询</Button>
|
||||||
|
</>
|
||||||
|
</Form>
|
||||||
|
<Card>
|
||||||
|
<Table columns={columns} dataSource={pageData} pagination={{
|
||||||
|
currentPage: activePage,
|
||||||
|
pageSize: ITEMS_PER_PAGE,
|
||||||
|
total: logCount,
|
||||||
|
pageSizeOpts: [10, 20, 50, 100],
|
||||||
|
onPageChange: handlePageChange,
|
||||||
|
}} loading={loading} />
|
||||||
|
</Card>
|
||||||
|
<Modal
|
||||||
|
visible={isModalOpen}
|
||||||
|
onOk={() => setIsModalOpen(false)}
|
||||||
|
onCancel={() => setIsModalOpen(false)}
|
||||||
|
closable={null}
|
||||||
|
bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
|
||||||
|
width={800} // 设置模态框宽度
|
||||||
|
>
|
||||||
|
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
|
||||||
|
</Modal>
|
||||||
|
</Layout>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default LogsTable;
|
@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'blue',
|
color: 'blue',
|
||||||
label: 'Midjourney Proxy Plus',
|
label: 'Midjourney Proxy Plus',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 36,
|
||||||
|
text: 'Suno API',
|
||||||
|
value: 36,
|
||||||
|
color: 'purple',
|
||||||
|
label: 'Suno API',
|
||||||
|
},
|
||||||
{ key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
|
{ key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
|
||||||
{
|
{
|
||||||
key: 14,
|
key: 14,
|
||||||
|
@ -6,6 +6,7 @@ export function setStatusData(data) {
|
|||||||
localStorage.setItem('quota_per_unit', data.quota_per_unit);
|
localStorage.setItem('quota_per_unit', data.quota_per_unit);
|
||||||
localStorage.setItem('display_in_currency', data.display_in_currency);
|
localStorage.setItem('display_in_currency', data.display_in_currency);
|
||||||
localStorage.setItem('enable_drawing', data.enable_drawing);
|
localStorage.setItem('enable_drawing', data.enable_drawing);
|
||||||
|
localStorage.setItem('enable_task', data.enable_task);
|
||||||
localStorage.setItem('enable_data_export', data.enable_data_export);
|
localStorage.setItem('enable_data_export', data.enable_data_export);
|
||||||
localStorage.setItem(
|
localStorage.setItem(
|
||||||
'data_export_default_time',
|
'data_export_default_time',
|
||||||
|
@ -126,6 +126,12 @@ const EditChannel = (props) => {
|
|||||||
'mj_uploads',
|
'mj_uploads',
|
||||||
];
|
];
|
||||||
break;
|
break;
|
||||||
|
case 36:
|
||||||
|
localModels = [
|
||||||
|
'suno_music',
|
||||||
|
'suno_lyrics',
|
||||||
|
];
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
localModels = getChannelModels(value);
|
localModels = getChannelModels(value);
|
||||||
break;
|
break;
|
||||||
@ -513,12 +519,32 @@ const EditChannel = (props) => {
|
|||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<div style={{ marginTop: 10 }}>
|
{inputs.type === 36 && (
|
||||||
|
<>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<Typography.Text strong>
|
||||||
|
注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
<Input
|
||||||
|
name='base_url'
|
||||||
|
placeholder={
|
||||||
|
'请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com '
|
||||||
|
}
|
||||||
|
onChange={(value) => {
|
||||||
|
handleInputChange('base_url', value);
|
||||||
|
}}
|
||||||
|
value={inputs.base_url}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
<Typography.Text strong>名称:</Typography.Text>
|
<Typography.Text strong>名称:</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
<Input
|
<Input
|
||||||
required
|
required
|
||||||
name='name'
|
name='name'
|
||||||
placeholder={'请为渠道命名'}
|
placeholder={'请为渠道命名'}
|
||||||
onChange={(value) => {
|
onChange={(value) => {
|
||||||
handleInputChange('name', value);
|
handleInputChange('name', value);
|
||||||
@ -758,7 +784,7 @@ const EditChannel = (props) => {
|
|||||||
</Space>
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
|
||||||
<>
|
<>
|
||||||
<div style={{ marginTop: 10 }}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Typography.Text strong>代理:</Typography.Text>
|
<Typography.Text strong>代理:</Typography.Text>
|
||||||
|
10
web/src/pages/Task/index.js
Normal file
10
web/src/pages/Task/index.js
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import TaskLogsTable from "../../components/TaskLogsTable.js";
|
||||||
|
|
||||||
|
const Task = () => (
|
||||||
|
<>
|
||||||
|
<TaskLogsTable />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
|
||||||
|
export default Task;
|
Loading…
Reference in New Issue
Block a user