diff --git a/middleware/distributor.go b/middleware/distributor.go index 35cb6df..e922662 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) { userId := c.GetInt("id") var channel *model.Channel channelId, ok := c.Get("specific_channel_id") + modelRequest, shouldSelectChannel, err := getModelRequest(c) + userGroup, _ := model.CacheGetUserGroup(userId) + c.Set("group", userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) { return } } else { - shouldSelectChannel := true // Select a channel for the user - var modelRequest ModelRequest - var err error - if strings.Contains(c.Request.URL.Path, "/mj/") { - relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) - if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || - relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || - relayMode == relayconstant.RelayModeMidjourneyNotify || - relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { - shouldSelectChannel = false - } else { - midjourneyRequest := dto.MidjourneyRequest{} - err = common.UnmarshalBodyReusable(c, &midjourneyRequest) - if err != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) - return - } - midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) - if mjErr != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) - return - } - if midjourneyModel == "" { - if !success { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") - return - } else { - // task fetch, task fetch by condition, notify - shouldSelectChannel = false - } - } - modelRequest.Model = midjourneyModel - } - c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } - if err != nil { - abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } - } - } // check token model mapping modelLimitEnable := c.GetBool("token_model_limit_enabled") if modelLimitEnable { @@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) { } } - userGroup, _ := model.CacheGetUserGroup(userId) - c.Set("group", userGroup) if shouldSelectChannel { channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) if err != nil { @@ -147,13 +83,82 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } - SetupContextForSelectedChannel(c, channel, modelRequest.Model) } } + SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() } } +func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { + var modelRequest ModelRequest + shouldSelectChannel := true + var err error + if strings.Contains(c.Request.URL.Path, "/mj/") { + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return nil, false, err + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return nil, false, fmt.Errorf(mjErr.Description) + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return nil, false, fmt.Errorf("无效的请求, 无法解析模型") + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel + } + c.Set("relay_mode", relayMode) + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + return nil, false, err + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = "tts-1" + } else { + modelRequest.Model = "whisper-1" + } + } + } + return &modelRequest, shouldSelectChannel, nil +} + func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id)