From 3d10c9f090300c5653823556ca224a5d86cb86e7 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 21:19:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=B0=86=E6=93=8D=E4=BD=9C=E6=8B=86?= =?UTF-8?q?=E5=88=86=E6=88=90=E5=8D=95=E7=8B=AC=E7=9A=84=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 32 +++-- constant/midjourney.go | 18 +++ controller/midjourney.go | 131 ------------------ controller/model.go | 18 ++- controller/relay.go | 12 +- dto/midjourney.go | 7 + middleware/auth.go | 8 +- middleware/distributor.go | 129 +++++++++++------- middleware/utils.go | 12 +- relay/relay-mj.go | 195 +++++++-------------------- service/midjourney.go | 135 +++++++++++++++++++ web/src/pages/Channel/EditChannel.js | 19 ++- 12 files changed, 366 insertions(+), 350 deletions(-) create mode 100644 service/midjourney.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 791f733..153b748 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -94,17 +94,30 @@ var ModelRatio = map[string]float64{ "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } -var ModelPrice = map[string]float64{ - "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, +var DefaultModelPrice = map[string]float64{ + "gpt-4-gizmo-*": 0.1, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05, } +var ModelPrice = map[string]float64{} + func ModelPrice2JSONString() string { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } jsonBytes, err := json.Marshal(ModelPrice) if err != nil { SysError("error marshalling model price: " + err.Error()) @@ -118,6 +131,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error { } func GetModelPrice(name string, printErr bool) float64 { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } diff --git a/constant/midjourney.go b/constant/midjourney.go index 5435a43..92e2f23 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -11,6 +11,7 @@ const ( MjActionBlend = "BLEND" MjActionUpscale = "UPSCALE" MjActionVariation = "VARIATION" + MjActionReRoll = "REROLL" MjActionInPaint = "INPAINT" MjActionInPaintPre = "INPAINT_PRE" MjActionZoom = "ZOOM" @@ -20,3 +21,20 @@ const ( MjActionPan = "PAN" SwapFace = "SWAP_FACE" ) + +var MidjourneyModel2Action = map[string]string{ + "mj_imagine": MjActionImagine, + "mj_describe": MjActionDescribe, + "mj_blend": MjActionBlend, + "mj_upscale": MjActionUpscale, + "mj_variation": MjActionVariation, + "mj_reroll": MjActionReRoll, + "mj_inpaint": MjActionInPaint, + "mj_inpaint_pre": MjActionInPaintPre, + "mj_zoom": MjActionZoom, + "mj_shorten": MjActionShorten, + "mj_high_variation": MjActionHighVariation, + "mj_low_variation": MjActionLowVariation, + "mj_pan": MjActionPan, + "swap_face": SwapFace, +} diff --git a/controller/midjourney.go b/controller/midjourney.go index b666e91..6256471 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -18,137 +18,6 @@ import ( "time" ) -/*func UpdateMidjourneyTask() { - //revocer - //imageModel := "midjourney" - ctx := context.TODO() - imageModel := "midjourney" - defer func() { - if err := recover(); err != nil { - log.Printf("UpdateMidjourneyTask panic: %v", err) - } - }() - for { - time.Sleep(time.Duration(15) * time.Second) - tasks := model.GetAllUnFinishTasks() - if len(tasks) != 0 { - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) - for _, task := range tasks { - common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task)) - midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) - if err != nil { - common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err)) - task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId) - task.Status = "FAILURE" - task.Progress = "100%" - err := task.Update() - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) - continue - } - continue - } - requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) - common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl)) - - req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err)) - continue - } - - // 设置超时时间 - timeout := time.Second * 5 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) - - req.Header.Set("Content-Type", "application/json") - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") - req.Header.Set("mj-api-secret", midjourneyChannel.Key) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) - continue - } - responseBody, err := io.ReadAll(resp.Body) - resp.Body.Close() - log.Printf("responseBody: %s", string(responseBody)) - var responseItem MidjourneyDto - // err = json.NewDecoder(resp.Body).Decode(&responseItem) - err = json.Unmarshal(responseBody, &responseItem) - if err != nil { - if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") { - var responseWithoutStatus MidjourneyWithoutStatus - var responseStatus MidjourneyStatus - err1 := json.Unmarshal(responseBody, &responseWithoutStatus) - err2 := json.Unmarshal(responseBody, &responseStatus) - if err1 == nil && err2 == nil { - jsonData, err3 := json.Marshal(responseWithoutStatus) - if err3 != nil { - log.Printf("UpdateMidjourneyTask error1: %v", err3) - continue - } - err4 := json.Unmarshal(jsonData, &responseStatus) - if err4 != nil { - log.Printf("UpdateMidjourneyTask error2: %v", err4) - continue - } - responseItem.Status = strconv.Itoa(responseStatus.Status) - } else { - log.Printf("UpdateMidjourneyTask error3: %v", err) - continue - } - } else { - log.Printf("UpdateMidjourneyTask error4: %v", err) - continue - } - } - task.Code = 1 - task.Progress = responseItem.Progress - task.PromptEn = responseItem.PromptEn - task.State = responseItem.State - task.SubmitTime = responseItem.SubmitTime - task.StartTime = responseItem.StartTime - task.FinishTime = responseItem.FinishTime - task.ImageUrl = responseItem.ImageUrl - task.Status = responseItem.Status - task.FailReason = responseItem.FailReason - if task.Progress != "100%" && responseItem.FailReason != "" { - common.LogWarn(task.MjId + " 构建失败," + task.FailReason) - task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) - if err != nil { - log.Println("error update user quota cache: " + err.Error()) - } else { - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio("default") - ratio := modelRatio * groupRatio - quota := int(ratio * 1 * 1000) - if quota != 0 { - err := model.IncreaseUserQuota(task.UserId, quota) - if err != nil { - log.Println("fail to increase user quota") - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - - err = task.Update() - if err != nil { - log.Printf("UpdateMidjourneyTask error5: %v", err) - } - log.Printf("UpdateMidjourneyTask success: %v", task) - cancel() - } - } - } -} -*/ - func UpdateMidjourneyTaskBulk() { //imageModel := "midjourney" ctx := context.TODO() diff --git a/controller/model.go b/controller/model.go index 38c6c46..9a106aa 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" "one-api/relay/channel/moonshot" - "one-api/relay/constant" + relayconstant "one-api/relay/constant" ) // https://platform.openai.com/docs/api-reference/models/list @@ -59,8 +60,8 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < relayconstant.APITypeDummy; i++ { + if i == relayconstant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) @@ -100,6 +101,17 @@ func init() { Parent: nil, }) } + for modelName, _ := range constant.MidjourneyModel2Action { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "midjourney", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/controller/relay.go b/controller/relay.go index d35c6a2..fa5493a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { } func RelayMidjourney(c *gin.Context) { - relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path) + relayMode := c.GetInt("relay_mode") var err *dto.MidjourneyResponse switch relayMode { case relayconstant.RelayModeMidjourneyNotify: @@ -73,13 +73,15 @@ func RelayMidjourney(c *gin.Context) { //err = relayMidjourneySubmit(c, relayMode) log.Println(err) if err != nil { + statusCode := http.StatusBadRequest if err.Code == 30 { err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + statusCode = http.StatusTooManyRequests } - c.JSON(429, gin.H{ - "error": fmt.Sprintf("%s %s", err.Description, err.Result), - "type": "upstream_error", - "code": err.Code, + c.JSON(statusCode, gin.H{ + "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": "upstream_error", + "code": err.Code, }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) diff --git a/dto/midjourney.go b/dto/midjourney.go index d3b19d5..f81756c 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -1,5 +1,12 @@ package dto +//type SimpleMjRequest struct { +// Prompt string `json:"prompt"` +// CustomId string `json:"customId"` +// Action string `json:"action"` +// Content string `json:"content"` +//} + type MidjourneyRequest struct { Prompt string `json:"prompt"` CustomId string `json:"customId"` diff --git a/middleware/auth.go b/middleware/auth.go index a8dac30..4b865c2 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) { } token, err := model.ValidateUserToken(key) if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) + abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -129,7 +129,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 1ca43dd..c1b3ccb 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,7 +4,11 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + relayconstant "one-api/relay/constant" + "one-api/service" "strconv" "strings" @@ -23,32 +27,58 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } if channel.Status != common.ChannelStatusEnabled { - abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") + abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { + shouldSelectChannel := true // Select a channel for the user var modelRequest ModelRequest var err error if strings.HasPrefix(c.Request.URL.Path, "/mj") { - // Midjourney - if modelRequest.Model == "" { - modelRequest.Model = "midjourney" + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel } + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -87,60 +117,61 @@ func Distribute() func(c *gin.Context) { } if tokenModelLimit != nil { if _, ok := tokenModelLimit[modelRequest.Model]; !ok { - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) return } } else { // token model limit is empty, all models are not allowed - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") return } } userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) - - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) - // 如果错误,但是渠道不为空,说明是数据库一致性问题 - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" + if shouldSelectChannel { + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + // 如果错误,但是渠道不为空,说明是数据库一致性问题 + if channel != nil { + common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + message = "数据库一致性已被破坏,请联系管理员" + } + // 如果错误,而且渠道为空,说明是没有可用渠道 + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + return + } + if channel == nil { + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) + return + } + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + c.Set("auto_ban", ban) + c.Set("model_mapping", channel.GetModelMapping()) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 + switch channel.Type { + case common.ChannelTypeAzure: + c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } - // 如果错误,而且渠道为空,说明是没有可用渠道 - abortWithMessage(c, http.StatusServiceUnavailable, message) - return } - if channel == nil { - abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) - return - } - } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - c.Set("auto_ban", ban) - c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - // TODO: api_version统一 - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) } c.Next() } diff --git a/middleware/utils.go b/middleware/utils.go index 021002d..43801c1 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -5,7 +5,7 @@ import ( "one-api/common" ) -func abortWithMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), @@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() common.LogError(c.Request.Context(), message) } + +func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { + c.JSON(statusCode, gin.H{ + "description": description, + "type": "new_api_error", + "code": code, + }) + c.Abort() + common.LogError(c.Request.Context(), description) +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f391f14..6cdd9e0 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -21,23 +21,6 @@ import ( "github.com/gin-gonic/gin" ) -var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_inpaint": 0.1, - "mj_zoom": 0.1, - "mj_shorten": 0.1, - "mj_high_variation": 0.1, - "mj_low_variation": 0.1, - "mj_pan": 0.1, - "mj_inpaint_pre": 0, - "mj_describe": 0.05, - "mj_upscale": 0.05, - "swap_face": 0.05, -} - func RelayMidjourneyImage(c *gin.Context) { taskId := c.Param("id") midjourneyTask := model.GetByOnlyMJId(taskId) @@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse } func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - imageModel := "midjourney" tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") + //channelType := c.GetInt("channel") userId := c.GetInt("id") group := c.GetString("group") channelId := c.GetInt("channel_id") @@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 - mjErr := coverPlusActionToNormalAction(&midjRequest) + mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } @@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } - params := convertSimpleChangeParams(midjRequest.Content) + params := service.ConvertSimpleChangeParams(midjRequest.Content) if params == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") } - mjId = params.ID + mjId = params.TaskId midjRequest.Action = params.Action } else if relayMode == relayconstant.RelayModeMidjourneyModal { if midjRequest.MaskBase64 == "" { @@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt - if channelType == common.ChannelTypeMidjourneyPlus { - // plus - } else { - // 普通版渠道 - - } + //if channelType == common.ChannelTypeMidjourneyPlus { + // // plus + //} else { + // // 普通版渠道 + // + //} } if midjRequest.Action == constant.MjActionInPaintPre { @@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_model_mapping_failed", - } - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } - } + //modelMapping := c.GetString("model_mapping") + //isModelMapped := false + //if modelMapping != "" { + // modelMap := make(map[string]string) + // err := json.Unmarshal([]byte(modelMapping), &modelMap) + // if err != nil { + // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + // return &dto.MidjourneyResponse{ + // Code: 4, + // Description: "unmarshal_model_mapping_failed", + // } + // } + // if modelMap[imageModel] != "" { + // imageModel = modelMap[imageModel] + // isModelMapped = true + // } + //} - baseURL := common.ChannelBaseURLs[channelType] + //baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } + baseURL := c.GetString("base_url") //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "marshal_text_request_failed", - } - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } + //if isModelMapped { + // jsonStr, err := json.Marshal(midjRequest) + // if err != nil { + // return &dto.MidjourneyResponse{ + // Code: 4, + // Description: "marshal_text_request_failed", + // } + // } + // requestBody = bytes.NewBuffer(jsonStr) + //} else { + //} + requestBody = c.Request.Body - mjAction := "mj_" + strings.ToLower(midjRequest.Action) - modelPrice := common.GetModelPrice(mjAction, true) + modelName := service.CoverActionToModelName(midjRequest.Action) + modelPrice := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if modelPrice == -1 { - defaultPrice, ok := DefaultModelPrice[mjAction] + defaultPrice, ok := common.DefaultModelPrice[modelName] if !ok { modelPrice = 0.1 } else { @@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) @@ -558,85 +541,3 @@ type taskChangeParams struct { Action string Index int } - -func convertSimpleChangeParams(content string) *taskChangeParams { - split := strings.Split(content, " ") - if len(split) != 2 { - return nil - } - - action := strings.ToLower(split[1]) - changeParams := &taskChangeParams{} - changeParams.ID = split[0] - - if action[0] == 'u' { - changeParams.Action = "UPSCALE" - } else if action[0] == 'v' { - changeParams.Action = "VARIATION" - } else if action == "r" { - changeParams.Action = "REROLL" - return changeParams - } else { - return nil - } - - index, err := strconv.Atoi(action[1:2]) - if err != nil || index < 1 || index > 4 { - return nil - } - changeParams.Index = index - return changeParams -} - -func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { - // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" - customId := midjRequest.CustomId - if customId == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") - } - splits := strings.Split(customId, "::") - var action string - if splits[1] == "JOB" { - action = splits[2] - } else { - action = splits[1] - } - - if action == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") - } - if strings.Contains(action, "upsample") { - index, err := strconv.Atoi(splits[3]) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") - } - midjRequest.Index = index - midjRequest.Action = constant.MjActionUpscale - } else if strings.Contains(action, "variation") { - midjRequest.Index = 1 - if action == "variation" { - index, err := strconv.Atoi(splits[3]) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") - } - midjRequest.Index = index - midjRequest.Action = constant.MjActionVariation - } else if action == "low_variation" { - midjRequest.Action = constant.MjActionLowVariation - } else if action == "high_variation" { - midjRequest.Action = constant.MjActionHighVariation - } - } else if strings.Contains(action, "pan") { - midjRequest.Action = constant.MjActionPan - midjRequest.Index = 1 - } else if action == "Outpaint" || action == "CustomZoom" { - midjRequest.Action = constant.MjActionZoom - midjRequest.Index = 1 - } else if action == "Inpaint" { - midjRequest.Action = constant.MjActionInPaintPre - midjRequest.Index = 1 - } else { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") - } - return nil -} diff --git a/service/midjourney.go b/service/midjourney.go new file mode 100644 index 0000000..06730c8 --- /dev/null +++ b/service/midjourney.go @@ -0,0 +1,135 @@ +package service + +import ( + "one-api/constant" + "one-api/dto" + relayconstant "one-api/relay/constant" + "strconv" + "strings" +) + +func CoverActionToModelName(mjAction string) string { + modelName := "mj_" + strings.ToLower(mjAction) + return modelName +} + +func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) { + action := "" + if relayMode == relayconstant.RelayModeMidjourneyAction { + // plus request + err := CoverPlusActionToNormalAction(midjRequest) + if err != nil { + return "", err, false + } + action = midjRequest.Action + } else { + switch relayMode { + case relayconstant.RelayModeMidjourneyImagine: + action = constant.MjActionImagine + case relayconstant.RelayModeMidjourneyDescribe: + action = constant.MjActionDescribe + case relayconstant.RelayModeMidjourneyBlend: + action = constant.MjActionBlend + case relayconstant.RelayModeMidjourneyShorten: + action = constant.MjActionShorten + case relayconstant.RelayModeMidjourneyChange: + action = midjRequest.Action + case relayconstant.RelayModeMidjourneyModal: + action = constant.MjActionInPaint + case relayconstant.RelayModeMidjourneySimpleChange: + params := ConvertSimpleChangeParams(midjRequest.Content) + if params == nil { + return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false + } + action = params.Action + case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: + return "", nil, true + default: + return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false + } + } + modelName := CoverActionToModelName(action) + return modelName, nil, true +} + +func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Index = 1 + if action == "variation" { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionVariation + } else if action == "low_variation" { + midjRequest.Action = constant.MjActionLowVariation + } else if action == "high_variation" { + midjRequest.Action = constant.MjActionHighVariation + } + } else if strings.Contains(action, "pan") { + midjRequest.Action = constant.MjActionPan + midjRequest.Index = 1 + } else if action == "Outpaint" || action == "CustomZoom" { + midjRequest.Action = constant.MjActionZoom + midjRequest.Index = 1 + } else if action == "Inpaint" { + midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Index = 1 + } else { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + return nil +} + +func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { + split := strings.Split(content, " ") + if len(split) != 2 { + return nil + } + + action := strings.ToLower(split[1]) + changeParams := &dto.MidjourneyRequest{} + changeParams.TaskId = split[0] + + if action[0] == 'u' { + changeParams.Action = "UPSCALE" + } else if action[0] == 'v' { + changeParams.Action = "VARIATION" + } else if action == "r" { + changeParams.Action = "REROLL" + return changeParams + } else { + return nil + } + + index, err := strconv.Atoi(action[1:2]) + if err != nil || index < 1 || index > 4 { + return nil + } + changeParams.Index = index + return changeParams +} diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index ee79368..225ce3f 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -96,10 +96,25 @@ const EditChannel = (props) => { localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; break; case 2: - localModels = ['midjourney']; + localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe']; break; case 5: - localModels = ['midjourney']; + localModels = [ + 'swap_face', + 'mj_imagine', + 'mj_variation', + 'mj_reroll', + 'mj_blend', + 'mj_upscale', + 'mj_describe', + 'mj_zoom', + 'mj_shorten', + 'mj_inpaint_pre', + 'mj_inpaint_pre', + 'mj_high_variation', + 'mj_low_variation', + 'mj_pan', + ]; break; } setInputs((inputs) => ({...inputs, models: localModels}));