diff --git a/README.md b/README.md index efd692f..5daa889 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,11 @@ 2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain 3. 选择你的bot,然后输入http(s)://你的网站地址/login 4. Telegram Bot 名称是bot username 去掉@后的字符串 +13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下: + + [x] /suno/submit/music + + [x] /suno/submit/lyrics + + [x] /suno/fetch + + [x] /suno/fetch/:id ## 模型支持 此版本额外支持以下模型: @@ -57,6 +62,7 @@ 5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md) 6. [零一万物](https://platform.lingyiwanwu.com/) 7. 自定义渠道,支持填入完整调用地址 +8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 @@ -105,6 +111,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234 ## Midjourney接口设置文档 [对接文档](Midjourney.md) +## Suno接口设置文档 +[对接文档](Suno.md) + ## 交流群 diff --git a/Suno.md b/Suno.md new file mode 100644 index 0000000..05fc9f3 --- /dev/null +++ b/Suno.md @@ -0,0 +1,37 @@ +# Suno API文档 + +**简介**:Suno API文档 + +## 模型列表 + +### Suno API支持 + +- suno_music (自定义模式、灵感模式、续写) +- suno_lyrics (生成歌词) + + +## 模型价格设置(在设置-运营设置-模型固定价格设置中设置) +```json +{ + "suno_music": 0.3, + "suno_lyrics": 0.01 +} +``` + +## 渠道设置 + +### 对接 Suno API + +1. +部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API) + +2. 在渠道管理中添加渠道,渠道类型选择**Suno API** + ,模型请参考上方模型列表 +3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080 +4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填 + +### 对接上游new api + +1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型 +2. **代理**填写上游new api的地址,例如:http://localhost:3000 +3. 密钥填写上游new api的密钥 \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index f5dbb3d..1bca141 100644 --- a/common/constants.go +++ b/common/constants.go @@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true var DrawingEnabled = true +var TaskEnabled = true var DataExportEnabled = true var DataExportInterval = 5 // unit: minute var DataExportDefaultTime = "hour" // unit: minute @@ -208,8 +209,10 @@ const ( ChannelTypeAws = 33 ChannelTypeCohere = 34 ChannelTypeMiniMax = 35 + ChannelTypeSunoAPI = 36 ChannelTypeDummy // this one is only for count, do not add any channel after this + ) var ChannelBaseURLs = []string{ @@ -249,4 +252,5 @@ var ChannelBaseURLs = []string{ "", //33 "https://api.cohere.ai", //34 "https://api.minimax.chat", //35 + "", //36 } diff --git a/constant/task.go b/constant/task.go new file mode 100644 index 0000000..1a68b81 --- /dev/null +++ b/constant/task.go @@ -0,0 +1,18 @@ +package constant + +type TaskPlatform string + +const ( + TaskPlatformSuno TaskPlatform = "suno" + TaskPlatformMidjourney = "mj" +) + +const ( + SunoActionMusic = "MUSIC" + SunoActionLyrics = "LYRICS" +) + +var SunoModel2Action = map[string]string{ + "suno_music": SunoActionMusic, + "suno_lyrics": SunoActionLyrics, +} diff --git a/controller/channel-test.go b/controller/channel-test.go index db03e75..0b8d442 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } + if channel.Type == common.ChannelTypeSunoAPI { + return errors.New("suno channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ diff --git a/controller/misc.go b/controller/misc.go index b8203f3..5e12854 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) { "display_in_currency": common.DisplayInCurrencyEnabled, "enable_batch_update": common.BatchUpdateEnabled, "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, "enable_data_export": common.DataExportEnabled, "data_export_default_time": common.DataExportDefaultTime, "default_collapse_sidebar": common.DefaultCollapseSidebar, diff --git a/controller/relay.go b/controller/relay.go index a066e5d..e7b8198 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) { "error": err, }) } + +func RelayTask(c *gin.Context) { + retryTimes := common.RetryTimes + channelId := c.GetInt("channel_id") + relayMode := c.GetInt("relay_mode") + group := c.GetString("group") + originalModel := c.GetString("original_model") + c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) + taskErr := taskRelayHandler(c, relayMode) + if taskErr == nil { + retryTimes = 0 + } + for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + if err != nil { + common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + break + } + channelId = channel.Id + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) + common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + taskErr = taskRelayHandler(c, relayMode) + } + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c.Request.Context(), retryLogStr) + } + if taskErr != nil { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(taskErr.StatusCode, taskErr) + } +} + +func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { + var err *dto.TaskError + switch relayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + err = relay.RelayTaskFetch(c, relayMode) + default: + err = relay.RelayTaskSubmit(c, relayMode) + } + return err +} + +func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { + if taskErr == nil { + return false + } + if retryTimes <= 0 { + return false + } + if _, ok := c.Get("specific_channel_id"); ok { + return false + } + if taskErr.StatusCode == http.StatusTooManyRequests { + return true + } + if taskErr.StatusCode == 307 { + return true + } + if taskErr.StatusCode/100 == 5 { + // 超时不重试 + if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 { + return false + } + return true + } + if taskErr.StatusCode == http.StatusBadRequest { + return false + } + if taskErr.StatusCode == 408 { + // azure处理超时不重试 + return false + } + if taskErr.LocalError { + return false + } + if taskErr.StatusCode/100 == 2 { + return false + } + return true +} diff --git a/controller/task.go b/controller/task.go new file mode 100644 index 0000000..fce9e7f --- /dev/null +++ b/controller/task.go @@ -0,0 +1,284 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/samber/lo" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/model" + "one-api/relay" + "sort" + "strconv" + "time" +) + +func UpdateTaskBulk() { + //revocer + //imageModel := "midjourney" + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + ctx := context.TODO() + allTasks := model.GetAllUnFinishSyncTasks(500) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + if task.TaskID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[task.TaskID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + UpdateTaskByPlatform(platform, taskChannelM, taskM) + } + common.SysLog("任务进度轮询完成") + } +} + +func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) + case constant.TaskPlatformSuno: + _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + default: + common.SysLog("未知平台") + } +} + +func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) + if err != nil { + common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + channel, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + err = model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + } + return err + } + adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno) + if adaptor == nil { + return errors.New("adaptor not found") + } + resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ + "ids": taskIds, + }) + if err != nil { + common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = json.Unmarshal(responseBody, &responseItems) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !checkTaskNeedUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + err = model.CacheUpdateUserQuota(task.UserId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } else { + quota := task.Quota + if quota != 0 { + err = model.IncreaseUserQuota(task.UserId, quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + } + } + return nil +} + +func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := json.Marshal(oldTask.Data) + newData, _ := json.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +func GetAllTask(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + // 解析其他查询参数 + queryParams := model.SyncTaskQueryParams{ + Platform: constant.TaskPlatform(c.Query("platform")), + TaskID: c.Query("task_id"), + Status: c.Query("status"), + Action: c.Query("action"), + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, + } + + logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + if logs == nil { + logs = make([]*model.Task, 0) + } + + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func GetUserTask(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + userId := c.GetInt("id") + + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + + queryParams := model.SyncTaskQueryParams{ + Platform: constant.TaskPlatform(c.Query("platform")), + TaskID: c.Query("task_id"), + Status: c.Query("status"), + Action: c.Query("action"), + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, + } + + logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + if logs == nil { + logs = make([]*model.Task, 0) + } + + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} diff --git a/dto/suno.go b/dto/suno.go new file mode 100644 index 0000000..a6bb3eb --- /dev/null +++ b/dto/suno.go @@ -0,0 +1,129 @@ +package dto + +import ( + "encoding/json" +) + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +type SunoSubmitReq struct { + GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` + Prompt string `json:"prompt,omitempty"` + Mv string `json:"mv,omitempty"` + Title string `json:"title,omitempty"` + Tags string `json:"tags,omitempty"` + ContinueAt float64 `json:"continue_at,omitempty"` + TaskID string `json:"task_id,omitempty"` + ContinueClipId string `json:"continue_clip_id,omitempty"` + MakeInstrumental bool `json:"make_instrumental"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} + +type SunoDataResponse struct { + TaskID string `json:"task_id" gorm:"type:varchar(50);index"` + Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode + Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time" gorm:"index"` + StartTime int64 `json:"start_time" gorm:"index"` + FinishTime int64 `json:"finish_time" gorm:"index"` + Data json.RawMessage `json:"data" gorm:"type:json"` +} + +type SunoSong struct { + ID string `json:"id"` + VideoURL string `json:"video_url"` + AudioURL string `json:"audio_url"` + ImageURL string `json:"image_url"` + ImageLargeURL string `json:"image_large_url"` + MajorModelVersion string `json:"major_model_version"` + ModelName string `json:"model_name"` + Status string `json:"status"` + Title string `json:"title"` + Text string `json:"text"` + Metadata SunoMetadata `json:"metadata"` +} + +type SunoMetadata struct { + Tags string `json:"tags"` + Prompt string `json:"prompt"` + GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"` + AudioPromptID interface{} `json:"audio_prompt_id"` + Duration interface{} `json:"duration"` + ErrorType interface{} `json:"error_type"` + ErrorMessage interface{} `json:"error_message"` +} + +type SunoLyrics struct { + ID string `json:"id"` + Status string `json:"status"` + Title string `json:"title"` + Text string `json:"text"` +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id + Action string `json:"action"` // 任务类型, song, lyrics, description-mode + Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Data json.RawMessage `json:"data"` +} + +type SunoGoAPISubmitReq struct { + CustomMode bool `json:"custom_mode"` + + Input SunoGoAPISubmitReqInput `json:"input"` + + NotifyHook string `json:"notify_hook,omitempty"` +} + +type SunoGoAPISubmitReqInput struct { + GptDescriptionPrompt string `json:"gpt_description_prompt"` + Prompt string `json:"prompt"` + Mv string `json:"mv"` + Title string `json:"title"` + Tags string `json:"tags"` + ContinueAt float64 `json:"continue_at"` + TaskID string `json:"task_id"` + ContinueClipId string `json:"continue_clip_id"` + MakeInstrumental bool `json:"make_instrumental"` +} + +type GoAPITaskResponse[T any] struct { + Code int `json:"code"` + Message string `json:"message"` + Data T `json:"data"` + ErrorMessage string `json:"error_message,omitempty"` +} + +type GoAPITaskResponseData struct { + TaskID string `json:"task_id"` +} + +type GoAPIFetchResponseData struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + Input string `json:"input"` + Clips map[string]SunoSong `json:"clips"` +} diff --git a/dto/task.go b/dto/task.go new file mode 100644 index 0000000..afc186b --- /dev/null +++ b/dto/task.go @@ -0,0 +1,10 @@ +package dto + +type TaskError struct { + Code string `json:"code"` + Message string `json:"message"` + Data any `json:"data"` + StatusCode int `json:"-"` + LocalError bool `json:"-"` + Error error `json:"-"` +} diff --git a/main.go b/main.go index 37c6a0a..006c118 100644 --- a/main.go +++ b/main.go @@ -20,10 +20,10 @@ import ( _ "net/http/pprof" ) -//go:embed web/dist +// /go:embed web/dist var buildFS embed.FS -//go:embed web/dist/index.html +// /go:embed web/dist/index.html var indexPage []byte func main() { @@ -92,6 +92,9 @@ func main() { common.SafeGoroutine(func() { controller.UpdateMidjourneyTaskBulk() }) + common.SafeGoroutine(func() { + controller.UpdateTaskBulk() + }) if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") diff --git a/middleware/distributor.go b/middleware/distributor.go index ae5707f..4862a48 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelRequest.Model = midjourneyModel } c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/suno/") { + relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeSunoFetch || + relayMode == relayconstant.RelayModeSunoFetchByID { + shouldSelectChannel = false + } else { + modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action")) + modelRequest.Model = modelName + } + c.Set("platform", string(constant.TaskPlatformSuno)) + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) } diff --git a/model/main.go b/model/main.go index b6ad2cb..710ea05 100644 --- a/model/main.go +++ b/model/main.go @@ -140,6 +140,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Task{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/option.go b/model/option.go index 6aa59cb..45aa524 100644 --- a/model/option.go +++ b/model/option.go @@ -41,6 +41,7 @@ func InitOptionMap() { common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) @@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) { common.DisplayTokenStatEnabled = boolValue case "DrawingEnabled": common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue case "DataExportEnabled": common.DataExportEnabled = boolValue case "DefaultCollapseSidebar": diff --git a/model/task.go b/model/task.go new file mode 100644 index 0000000..df221ed --- /dev/null +++ b/model/task.go @@ -0,0 +1,304 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "one-api/constant" + commonRelay "one-api/relay/common" + "time" +) + +type TaskStatus string + +const ( + TaskStatusNotStart TaskStatus = "NOT_START" + TaskStatusSubmitted = "SUBMITTED" + TaskStatusQueued = "QUEUED" + TaskStatusInProgress = "IN_PROGRESS" + TaskStatusFailure = "FAILURE" + TaskStatusSuccess = "SUCCESS" + TaskStatusUnknown = "UNKNOWN" +) + +type Task struct { + ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"` + CreatedAt int64 `json:"created_at" gorm:"index"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id + Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台 + UserId int `json:"user_id" gorm:"index"` + ChannelId int `json:"channel_id" gorm:"index"` + Quota int `json:"quota"` + Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode + Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态 + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time" gorm:"index"` + StartTime int64 `json:"start_time" gorm:"index"` + FinishTime int64 `json:"finish_time" gorm:"index"` + Progress string `json:"progress" gorm:"type:varchar(20);index"` + Properties Properties `json:"properties" gorm:"type:json"` + + Data json.RawMessage `json:"data" gorm:"type:json"` +} + +func (t *Task) SetData(data any) { + b, _ := json.Marshal(data) + t.Data = json.RawMessage(b) +} + +func (t *Task) GetData(v any) error { + err := json.Unmarshal(t.Data, &v) + return err +} + +type Properties struct { + Input string `json:"input"` +} + +func (m *Properties) Scan(val interface{}) error { + bytesValue, _ := val.([]byte) + return json.Unmarshal(bytesValue, m) +} + +func (m Properties) Value() (driver.Value, error) { + return json.Marshal(m) +} + +// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 +type SyncTaskQueryParams struct { + Platform constant.TaskPlatform + ChannelID string + TaskID string + UserID string + Action string + Status string + StartTimestamp int64 + EndTimestamp int64 + UserIDs []int +} + +func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task { + t := &Task{ + UserId: relayInfo.UserId, + SubmitTime: time.Now().Unix(), + Status: TaskStatusNotStart, + Progress: "0%", + ChannelId: relayInfo.ChannelId, + Platform: platform, + } + return t +} + +func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { + var tasks []*Task + var err error + + // 初始化查询构建器 + query := DB.Where("user_id = ?", userId) + + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.StartTimestamp != 0 { + // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析 + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + + // 获取数据 + err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + + return tasks +} + +func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { + var tasks []*Task + var err error + + // 初始化查询构建器 + query := DB + + // 添加过滤条件 + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.UserID != "" { + query = query.Where("user_id = ?", queryParams.UserID) + } + if len(queryParams.UserIDs) != 0 { + query = query.Where("user_id in (?)", queryParams.UserIDs) + } + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + + // 获取数据 + err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + + return tasks +} + +func GetAllUnFinishSyncTasks(limit int) []*Task { + var tasks []*Task + var err error + // get all tasks progress is not 100% + err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetByOnlyTaskId(taskId string) (*Task, bool, error) { + if taskId == "" { + return nil, false, nil + } + var task *Task + var err error + err = DB.Where("task_id = ?", taskId).First(&task).Error + exist, err := RecordExist(err) + if err != nil { + return nil, false, err + } + return task, exist, err +} + +func GetByTaskId(userId int, taskId string) (*Task, bool, error) { + if taskId == "" { + return nil, false, nil + } + var task *Task + var err error + err = DB.Where("user_id = ? and task_id = ?", userId, taskId). + First(&task).Error + exist, err := RecordExist(err) + if err != nil { + return nil, false, err + } + return task, exist, err +} + +func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { + if len(taskIds) == 0 { + return nil, nil + } + var task []*Task + var err error + err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds). + Find(&task).Error + if err != nil { + return nil, err + } + return task, nil +} + +func TaskUpdateProgress(id int64, progress string) error { + return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error +} + +func (Task *Task) Insert() error { + var err error + err = DB.Create(Task).Error + return err +} + +func (Task *Task) Update() error { + var err error + err = DB.Save(Task).Error + return err +} + +func TaskBulkUpdate(TaskIds []string, params map[string]any) error { + if len(TaskIds) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("task_id in (?)", TaskIds). + Updates(params).Error +} + +func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { + if len(taskIDs) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("id in (?)", taskIDs). + Updates(params).Error +} + +func TaskBulkUpdateByID(ids []int64, params map[string]any) error { + if len(ids) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("id in (?)", ids). + Updates(params).Error +} + +type TaskQuotaUsage struct { + Mode string `json:"mode"` + Count float64 `json:"count"` +} + +func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { + query := DB.Model(Task{}) + // 添加过滤条件 + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.UserID != "" { + query = query.Where("user_id = ?", queryParams.UserID) + } + if len(queryParams.UserIDs) != 0 { + query = query.Where("user_id in (?)", queryParams.UserIDs) + } + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error + return stat, err +} diff --git a/model/utils.go b/model/utils.go index 1c28340..44bfbb9 100644 --- a/model/utils.go +++ b/model/utils.go @@ -1,6 +1,8 @@ package model import ( + "errors" + "gorm.io/gorm" "one-api/common" "sync" "time" @@ -75,3 +77,13 @@ func batchUpdate() { } common.SysLog("batch update finished") } + +func RecordExist(err error) (bool, error) { + if err == nil { + return true, nil + } + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err +} diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d3886d5..d87f476 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -19,3 +19,22 @@ type Adaptor interface { GetModelList() []string GetChannelName() string } + +type TaskAdaptor interface { + Init(info *relaycommon.TaskRelayInfo) + + ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError + + BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) + BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error + BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) + + DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError) + + GetModelList() []string + GetChannelName() string + + // FetchTask + FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) +} diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ef82645..ab1131f 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { _ = c.Request.Body.Close() return resp, nil } + +func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.BuildRequestURL(info) + if err != nil { + return nil, err + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(requestBody), nil + } + + err = a.BuildRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go new file mode 100644 index 0000000..03d6051 --- /dev/null +++ b/relay/channel/task/suno/adaptor.go @@ -0,0 +1,172 @@ +package suno + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +type TaskAdaptor struct { + ChannelType int +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + action := strings.ToUpper(c.Param("action")) + + var sunoRequest *dto.SunoSubmitReq + err := common.UnmarshalBodyReusable(c, &sunoRequest) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + err = actionValidate(c, sunoRequest, action) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + + if sunoRequest.ContinueClipId != "" { + if sunoRequest.TaskID == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest) + return + } + info.OriginTaskID = sunoRequest.TaskID + } + + info.Action = action + c.Set("task_request", sunoRequest) + return nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + baseURL := info.BaseUrl + fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) + return fullRequestURL, nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + sunoRequest, ok := c.Get("task_request") + if !ok { + err := common.UnmarshalBodyReusable(c, &sunoRequest) + if err != nil { + return nil, err + } + } + data, err := json.Marshal(sunoRequest) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + var sunoResponse dto.TaskResponse[string] + err = json.Unmarshal(responseBody, &sunoResponse) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + if !sunoResponse.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError) + return + } + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return + } + + return sunoResponse.Data, nil, nil +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) + byteBody, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) + if err != nil { + common.SysError(fmt.Sprintf("Get Task error: %v", err)) + return nil, err + } + defer req.Body.Close() + // 设置超时时间 + timeout := time.Second * 15 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+key) + resp, err := service.GetHttpClient().Do(req) + if err != nil { + return nil, err + } + return resp, nil +} + +func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) { + switch action { + case constant.SunoActionMusic: + if sunoRequest.Mv == "" { + sunoRequest.Mv = "chirp-v3-0" + } + case constant.SunoActionLyrics: + if sunoRequest.Prompt == "" { + err = fmt.Errorf("prompt_empty") + return + } + default: + err = fmt.Errorf("invalid_action") + } + return +} diff --git a/relay/channel/task/suno/models.go b/relay/channel/task/suno/models.go new file mode 100644 index 0000000..967cf1b --- /dev/null +++ b/relay/channel/task/suno/models.go @@ -0,0 +1,7 @@ +package suno + +var ModelList = []string{ + "suno_music", "suno_lyrics", +} + +var ChannelName = "suno" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b40352e..f93d36a 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) { func (info *RelayInfo) SetIsStream(isStream bool) { info.IsStream = isStream } + +type TaskRelayInfo struct { + ChannelType int + ChannelId int + TokenId int + UserId int + Group string + StartTime time.Time + ApiType int + RelayMode int + UpstreamModelName string + RequestURLPath string + ApiKey string + BaseUrl string + + Action string + OriginTaskID string + + ConsumeQuota bool +} + +func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + startTime := time.Now() + + apiType, _ := constant.ChannelType2APIType(channelType) + + info := &TaskRelayInfo{ + RelayMode: constant.Path2RelayMode(c.Request.URL.Path), + BaseUrl: c.GetString("base_url"), + RequestURLPath: c.Request.URL.String(), + ChannelType: channelType, + ChannelId: channelId, + TokenId: tokenId, + UserId: userId, + Group: group, + StartTime: startTime, + ApiType: apiType, + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + } + if info.BaseUrl == "" { + info.BaseUrl = common.ChannelBaseURLs[channelType] + } + return info +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 2e94bc0..fa19f50 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -1,6 +1,9 @@ package constant -import "strings" +import ( + "net/http" + "strings" +) const ( RelayModeUnknown = iota @@ -26,6 +29,9 @@ const ( RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + RelayModeSunoFetch + RelayModeSunoFetchByID + RelayModeSunoSubmit ) func Path2RelayMode(path string) int { @@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int { } return relayMode } + +func Path2RelaySuno(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/fetch") { + relayMode = RelayModeSunoFetch + } else if method == http.MethodGet && strings.Contains(path, "/fetch/") { + relayMode = RelayModeSunoFetchByID + } else if strings.Contains(path, "/submit/") { + relayMode = RelayModeSunoSubmit + } + return relayMode +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index cf63054..bfa13f4 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,6 +1,7 @@ package relay import ( + commonconstant "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" "one-api/relay/channel/aws" @@ -12,6 +13,7 @@ import ( "one-api/relay/channel/openai" "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" + "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" @@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor { } return nil } + +func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { + switch platform { + //case constant.APITypeAIProxyLibrary: + // return &aiproxy.Adaptor{} + case commonconstant.TaskPlatformSuno: + return &suno.TaskAdaptor{} + } + return nil +} diff --git a/relay/relay_task.go b/relay/relay_task.go new file mode 100644 index 0000000..47d8a5c --- /dev/null +++ b/relay/relay_task.go @@ -0,0 +1,242 @@ +package relay + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/model" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" +) + +/* +Task 任务通过平台、Action 区分任务 +*/ +func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { + platform := constant.TaskPlatform(c.GetString("platform")) + relayInfo := relaycommon.GenTaskRelayInfo(c) + + adaptor := GetTaskAdaptor(platform) + if adaptor == nil { + return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + // get & validate taskRequest 获取并验证文本请求 + taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) + if taskErr != nil { + return + } + + modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) + modelPrice, success := common.GetModelPrice(modelName, true) + if !success { + defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + + // 预扣 + groupRatio := common.GetGroupRatio(relayInfo.Group) + ratio := modelPrice * groupRatio + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return + } + quota := int(ratio * common.QuotaPerUnit) + if userQuota-quota < 0 { + taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) + return + } + + if relayInfo.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + if originTask.ChannelId != relayInfo.ChannelId { + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + return + } + if channel.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest) + } + c.Set("base_url", channel.GetBaseURL()) + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + relayInfo.BaseUrl = channel.GetBaseURL() + relayInfo.ChannelId = originTask.ChannelId + } + } + + // build body + requestBody, err := adaptor.BuildRequestBody(c, relayInfo) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) + return + } + // do request + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return + } + // handle response + if resp != nil && resp.StatusCode != http.StatusOK { + responseBody, _ := io.ReadAll(resp.Body) + taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode) + return + } + + defer func(ctx context.Context) { + // release quota + if relayInfo.ConsumeQuota && taskErr == nil { + err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) + other := make(map[string]interface{}) + other["model_price"] = modelPrice + other["group_ratio"] = groupRatio + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + } + }(c.Request.Context()) + + taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) + if taskErr != nil { + return + } + relayInfo.ConsumeQuota = true + // insert task + task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task.TaskID = taskID + task.Quota = quota + task.Data = taskData + err = task.Insert() + if err != nil { + taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) + return + } + return nil +} + +var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, +} + +func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { + respBuilder, ok := fetchRespBuilders[relayMode] + if !ok { + taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) + } + + respBody, taskErr := respBuilder(c) + if taskErr != nil { + return taskErr + } + + c.Writer.Header().Set("Content-Type", "application/json") + _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return + } + return +} + +func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + userId := c.GetInt("id") + var condition = struct { + IDs []any `json:"ids"` + Action string `json:"action"` + }{} + err := c.BindJSON(&condition) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) + return + } + var tasks []any + if len(condition.IDs) > 0 { + taskModels, err := model.GetByTaskIds(userId, condition.IDs) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) + return + } + for _, task := range taskModels { + tasks = append(tasks, TaskModel2Dto(task)) + } + } else { + tasks = make([]any, 0) + } + respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + Code: "success", + Data: tasks, + }) + return +} + +func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + +func TaskModel2Dto(task *model.Task) *dto.TaskDto { + return &dto.TaskDto{ + TaskID: task.TaskID, + Action: task.Action, + Status: string(task.Status), + FailReason: task.FailReason, + SubmitTime: task.SubmitTime, + StartTime: task.StartTime, + FinishTime: task.FinishTime, + Progress: task.Progress, + Data: task.Data, + } +} diff --git a/router/api-router.go b/router/api-router.go index 7657a98..6807939 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) { mjRoute := apiRouter.Group("/mj") mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) + + taskRoute := apiRouter.Group("/task") + { + taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask) + taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask) + } } } diff --git a/router/relay-router.go b/router/relay-router.go index 2d8e7b3..3ad9e37 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) { relayMjModeRouter := router.Group("/:mode/mj") registerMjRouterGroup(relayMjModeRouter) //relayMjRouter.Use() + + relaySunoRouter := router.Group("/suno") + relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) + { + relaySunoRouter.POST("/submit/:action", controller.RelayTask) + relaySunoRouter.POST("/fetch", controller.RelayTask) + relaySunoRouter.GET("/fetch/:id", controller.RelayTask) + } + } func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { diff --git a/service/error.go b/service/error.go index 4b00f37..0f6d472 100644 --- a/service/error.go +++ b/service/error.go @@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping openaiErr.StatusCode = intCode } } + +func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError { + openaiErr := TaskErrorWrapper(err, code, statusCode) + openaiErr.LocalError = true + return openaiErr +} + +func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { + text := err.Error() + + // 定义一个正则表达式匹配URL + if strings.Contains(text, "Post") || strings.Contains(text, "dial") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } + //避免暴露内部错误 + + taskError := &dto.TaskError{ + Code: code, + Message: text, + StatusCode: statusCode, + Error: err, + } + + return taskError +} diff --git a/service/task.go b/service/task.go new file mode 100644 index 0000000..c2501fe --- /dev/null +++ b/service/task.go @@ -0,0 +1,10 @@ +package service + +import ( + "one-api/constant" + "strings" +) + +func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string { + return strings.ToLower(string(platform)) + "_" + strings.ToLower(action) +} diff --git a/web/src/App.js b/web/src/App.js index 1b63def..0db9a22 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -23,6 +23,7 @@ import Chat from './pages/Chat'; import { Layout } from '@douyinfe/semi-ui'; import Midjourney from './pages/Midjourney'; import Pricing from './pages/Pricing/index.js'; +import Task from "./pages/Task/index.js"; // import Detail from './pages/Detail'; const Home = lazy(() => import('./pages/Home')); @@ -220,6 +221,16 @@ function App() { } /> + + }> + + + + } + /> { chat: '/chat', detail: '/detail', pricing: '/pricing', + task: '/task', }; const headerButtons = useMemo( @@ -142,6 +143,16 @@ const SiderBar = () => { ? 'semi-navigation-item-normal' : 'tableHiddle', }, + { + text: '异步任务', + itemKey: 'task', + to: '/task', + icon: , + className: + localStorage.getItem('enable_task') === 'true' + ? 'semi-navigation-item-normal' + : 'tableHiddle', + }, { text: '设置', itemKey: 'setting', @@ -158,6 +169,7 @@ const SiderBar = () => { [ localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), + localStorage.getItem('enable_task'), localStorage.getItem('chat_link'), isAdmin(), ], diff --git a/web/src/components/TaskLogsTable.js b/web/src/components/TaskLogsTable.js new file mode 100644 index 0000000..52bf39b --- /dev/null +++ b/web/src/components/TaskLogsTable.js @@ -0,0 +1,400 @@ +import React, { useEffect, useState } from 'react'; +import { Label } from 'semantic-ui-react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { + Table, + Tag, + Form, + Button, + Layout, + Modal, + Typography, Progress, Card +} from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +] + + +const renderTimestamp = (timestampInSeconds) => { + const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒 + + const year = date.getFullYear(); // 获取年份 + const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数 + const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数 + const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数 + const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数 + const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数 + + return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出 +}; + +function renderDuration(submit_time, finishTime) { + // 确保startTime和finishTime都是有效的时间戳 + if (!submit_time || !finishTime) return 'N/A'; + + // 将时间戳转换为Date对象 + const start = new Date(submit_time); + const finish = new Date(finishTime); + + // 计算时间差(毫秒) + const durationMs = finish - start; + + // 将时间差转换为秒,并保留一位小数 + const durationSec = (durationMs / 1000).toFixed(1); + + // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色 + const color = durationSec > 60 ? 'red' : 'green'; + + // 返回带有样式的颜色标签 + return ( + + {durationSec} 秒 + + ); +} + +const LogsTable = () => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [modalContent, setModalContent] = useState(''); + const isAdminUser = isAdmin(); + const columns = [ + { + title: "提交时间", + dataIndex: 'submit_time', + render: (text, record, index) => { + return ( + + {text ? renderTimestamp(text) : "-"} + + ); + }, + }, + { + title: "结束时间", + dataIndex: 'finish_time', + render: (text, record, index) => { + return ( + + {text ? renderTimestamp(text) : "-"} + + ); + }, + }, + { + title: '进度', + dataIndex: 'progress', + width: 50, + render: (text, record, index) => { + return ( + + { + // 转换例如100%为数字100,如果text未定义,返回0 + isNaN(text.replace('%', '')) ? text : + } + + ); + }, + }, + { + title: '花费时间', + dataIndex: 'finish_time', // 以finish_time作为dataIndex + key: 'finish_time', + render: (finish, record) => { + // 假设record.start_time是存在的,并且finish是完成时间的时间戳 + return <> + { + finish ? renderDuration(record.submit_time, finish) : "-" + } + > + }, + }, + { + title: "渠道", + dataIndex: 'channel_id', + className: isAdminUser ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( + + { + copyText(text); // 假设copyText是用于文本复制的函数 + }} + > + {' '} + {text}{' '} + + + ); + }, + }, + { + title: "平台", + dataIndex: 'platform', + render: (text, record, index) => { + return ( + + {renderPlatform(text)} + + ); + }, + }, + { + title: '类型', + dataIndex: 'action', + render: (text, record, index) => { + return ( + + {renderType(text)} + + ); + }, + }, + { + title: '任务ID(点击查看详情)', + dataIndex: 'task_id', + render: (text, record, index) => { + return ( { + setModalContent(JSON.stringify(record, null, 2)); + setIsModalOpen(true); + }} + > + + {text} + + ); + }, + }, + { + title: '任务状态', + dataIndex: 'status', + render: (text, record, index) => { + return ( + + {renderStatus(text)} + + ); + }, + }, + + { + title: '失败原因', + dataIndex: 'fail_reason', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + } + ]; + + const [logs, setLogs] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logType] = useState(0); + + let now = new Date(); + // 初始化start_timestamp为前一天 + let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate()); + const [inputs, setInputs] = useState({ + channel_id: '', + task_id: '', + start_timestamp: timestamp2string(zeroNow.getTime() /1000), + end_timestamp: '', + }); + const { channel_id, task_id, start_timestamp, end_timestamp } = inputs; + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + } + + const loadLogs = async (startIdx) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000); + let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 ); + if (isAdminUser) { + url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + let { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1).then(r => { + }); + } + }; + + const refresh = async () => { + // setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text }); + } + } + + useEffect(() => { + refresh().then(); + }, [logType]); + + const renderType = (type) => { + switch (type) { + case 'MUSIC': + return 生成音乐 ; + case 'LYRICS': + return 生成歌词 ; + + default: + return 未知 ; + } + } + + const renderPlatform = (type) => { + switch (type) { + case "suno": + return Suno ; + default: + return 未知 ; + } + } + + const renderStatus = (type) => { + switch (type) { + case 'SUCCESS': + return 成功 ; + case 'NOT_START': + return 未启动 ; + case 'SUBMITTED': + return 队列中 ; + case 'IN_PROGRESS': + return 执行中 ; + case 'FAILURE': + return 失败 ; + case 'QUEUED': + return 排队中 ; + case 'UNKNOWN': + return 未知 ; + case '': + return 正在提交 ; + default: + return 未知 ; + } + } + + return ( + <> + + + + <> + {isAdminUser && handleInputChange(value, 'channel_id')} /> + } + handleInputChange(value, 'task_id')} /> + + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + 查询 + > + + + + + setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式 + width={800} // 设置模态框宽度 + > + {modalContent} + + + > + ); +}; + +export default LogsTable; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 9b10f0f..e67dbc6 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Midjourney Proxy Plus', }, + { + key: 36, + text: 'Suno API', + value: 36, + color: 'purple', + label: 'Suno API', + }, { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' }, { key: 14, diff --git a/web/src/helpers/data.js b/web/src/helpers/data.js index 750b670..93380b4 100644 --- a/web/src/helpers/data.js +++ b/web/src/helpers/data.js @@ -6,6 +6,7 @@ export function setStatusData(data) { localStorage.setItem('quota_per_unit', data.quota_per_unit); localStorage.setItem('display_in_currency', data.display_in_currency); localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_task', data.enable_task); localStorage.setItem('enable_data_export', data.enable_data_export); localStorage.setItem( 'data_export_default_time', diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 821a056..08f4a63 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -126,6 +126,12 @@ const EditChannel = (props) => { 'mj_uploads', ]; break; + case 36: + localModels = [ + 'suno_music', + 'suno_lyrics', + ]; + break; default: localModels = getChannelModels(value); break; @@ -513,12 +519,32 @@ const EditChannel = (props) => { /> > )} - + {inputs.type === 36 && ( + <> + + + 注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用 + + + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + > + )} + 名称: { handleInputChange('name', value); @@ -758,7 +784,7 @@ const EditChannel = (props) => { )} - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( + {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && ( <> 代理: diff --git a/web/src/pages/Task/index.js b/web/src/pages/Task/index.js new file mode 100644 index 0000000..aec3702 --- /dev/null +++ b/web/src/pages/Task/index.js @@ -0,0 +1,10 @@ +import React from 'react'; +import TaskLogsTable from "../../components/TaskLogsTable.js"; + +const Task = () => ( + <> + + > +); + +export default Task;
{modalContent}