package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/model"
	relayconstant "one-api/relay/constant"
	"one-api/service"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
)

// ModelRequest holds the only field this middleware needs from a request
// body: the "model" JSON property.
type ModelRequest struct {
	Model string `json:"model"`
}

// Distribute returns a gin handler that resolves the channel for the current
// request, stores the channel details on the context and then calls c.Next().
//
// If the context carries "specific_channel_id", that channel is loaded and
// must be enabled. Otherwise the optional per-token model allow-list
// ("token_model_limit_enabled" / "token_model_limit") is checked, and, when
// getModelRequest says a channel is required, one is picked through
// model.CacheGetRandomSatisfiedChannel for the user's group and the requested
// model. Every failure path writes an error response and returns without
// calling c.Next().
func Distribute() func(c *gin.Context) {
	return func(c *gin.Context) {
		userId := c.GetInt("id")
		var channel *model.Channel
		channelId, ok := c.Get("specific_channel_id")

		modelRequest, shouldSelectChannel, err := getModelRequest(c)
		if err != nil {
			// getModelRequest writes its own error response before returning
			// an error. Only answer here when the context has not been
			// aborted yet, so the client never receives two bodies.
			// NOTE(review): assumes the abortWith* helpers call c.Abort();
			// if they do not, this falls back to the previous behaviour.
			if !c.IsAborted() {
				abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
			}
			return
		}

		// NOTE(review): the lookup error is ignored, so userGroup is whatever
		// CacheGetUserGroup returns alongside a failure — confirm that is the
		// intended fallback.
		userGroup, _ := model.CacheGetUserGroup(userId)
		c.Set("group", userGroup)

		if ok {
			// A specific channel was requested; it must parse, exist and be
			// enabled.
			idStr, isString := channelId.(string)
			if !isString {
				// Reject instead of panicking on an unexpected value type.
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
				return
			}
			id, err := strconv.Atoi(idStr)
			if err != nil {
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
				return
			}
			channel, err = model.GetChannelById(id, true)
			if err != nil {
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
				return
			}
			if channel.Status != common.ChannelStatusEnabled {
				abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
				return
			}
		} else {
			// Select a channel for the user.
			// First enforce the token's model allow-list, if enabled.
			if c.GetBool("token_model_limit_enabled") {
				var tokenModelLimit map[string]bool
				if s, found := c.Get("token_model_limit"); found {
					// A value of the wrong type leaves the map nil, which is
					// handled below as "no model allowed" rather than a panic.
					tokenModelLimit, _ = s.(map[string]bool)
				} else {
					tokenModelLimit = map[string]bool{}
				}
				if tokenModelLimit == nil {
					// Token model limit is nil: no model is allowed.
					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
					return
				}
				// Only key presence matters; the stored bool is not inspected.
				if _, allowed := tokenModelLimit[modelRequest.Model]; !allowed {
					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
					return
				}
			}

			if shouldSelectChannel {
				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
				if err != nil {
					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
					// An error together with a non-nil channel indicates a
					// database consistency problem.
					if channel != nil {
						common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
						message = "数据库一致性已被破坏,请联系管理员"
					}
					// An error with a nil channel means no channel is available.
					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
					return
				}
				if channel == nil {
					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
					return
				}
			}
		}

		SetupContextForSelectedChannel(c, channel, modelRequest.Model)
		c.Next()
	}
}

// getModelRequest works out which model the request targets and whether a
// channel has to be selected for it.
//
// The request path decides how the model is obtained:
//   - paths containing "/mj/": the relay mode is derived from the path and
//     stored as "relay_mode"; for non-fetch/notify modes the body is decoded
//     as a dto.MidjourneyRequest and the model comes from
//     service.GetMjRequestModel.
//   - paths containing "/suno/": the relay mode is derived from method and
//     path; for non-fetch modes the model comes from the "action" route
//     parameter. "platform" and "relay_mode" are stored on the context.
//   - anything else except "/v1/audio/transcriptions": the body is decoded
//     into a ModelRequest.
//
// Afterwards an empty model is replaced by a path-specific default.
//
// It returns the model request, whether a channel must be selected, and an
// error. Whenever the error is non-nil an error response has already been
// written through abortWithMidjourneyMessage or abortWithOpenAiMessage.
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
	var modelRequest ModelRequest
	shouldSelectChannel := true
	path := c.Request.URL.Path

	switch {
	case strings.Contains(path, "/mj/"):
		relayMode := relayconstant.Path2RelayModeMidjourney(path)
		if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
			relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
			relayMode == relayconstant.RelayModeMidjourneyNotify ||
			relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
			shouldSelectChannel = false
		} else {
			midjourneyRequest := dto.MidjourneyRequest{}
			if err := common.UnmarshalBodyReusable(c, &midjourneyRequest); err != nil {
				abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
				return nil, false, err
			}
			midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
			if mjErr != nil {
				abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
				// errors.New, not fmt.Errorf: the description is data, not a
				// format string, and may legitimately contain '%'.
				return nil, false, errors.New(mjErr.Description)
			}
			if midjourneyModel == "" {
				if !success {
					abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
					return nil, false, errors.New("无效的请求, 无法解析模型")
				}
				// task fetch, task fetch by condition, notify
				shouldSelectChannel = false
			}
			modelRequest.Model = midjourneyModel
		}
		c.Set("relay_mode", relayMode)

	case strings.Contains(path, "/suno/"):
		relayMode := relayconstant.Path2RelaySuno(c.Request.Method, path)
		if relayMode == relayconstant.RelayModeSunoFetch ||
			relayMode == relayconstant.RelayModeSunoFetchByID {
			shouldSelectChannel = false
		} else {
			modelRequest.Model = service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
		}
		c.Set("platform", string(constant.TaskPlatformSuno))
		c.Set("relay_mode", relayMode)

	case !strings.HasPrefix(path, "/v1/audio/transcriptions"):
		// Transcription requests are skipped here; every other path carries
		// the model in its body.
		if err := common.UnmarshalBodyReusable(c, &modelRequest); err != nil {
			abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
			return nil, false, errors.New("无效的请求, " + err.Error())
		}
	}

	// Path-specific defaults for requests that did not name a model. The
	// checks run in sequence, so the first one that fills the model wins.
	if strings.HasPrefix(path, "/v1/moderations") && modelRequest.Model == "" {
		modelRequest.Model = "text-moderation-stable"
	}
	if strings.HasSuffix(path, "embeddings") && modelRequest.Model == "" {
		modelRequest.Model = c.Param("model")
	}
	if strings.HasPrefix(path, "/v1/images/generations") && modelRequest.Model == "" {
		modelRequest.Model = "dall-e"
	}
	if strings.HasPrefix(path, "/v1/audio") && modelRequest.Model == "" {
		if strings.HasPrefix(path, "/v1/audio/speech") {
			modelRequest.Model = "tts-1"
		} else {
			modelRequest.Model = "whisper-1"
		}
	}

	return &modelRequest, shouldSelectChannel, nil
}

// SetupContextForSelectedChannel stores the requested model name and, when
// channel is non-nil, the channel's details on the gin context and rewrites
// the request's Authorization header to use the channel key.
//
// "original_model" is always set (it is needed for retries); with a nil
// channel nothing else is touched.
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
	c.Set("original_model", modelName) // for retry
	if channel == nil {
		return
	}

	c.Set("channel", channel.Type)
	c.Set("channel_id", channel.Id)
	c.Set("channel_name", channel.Name)
	c.Set("channel_type", channel.Type)

	// AutoBan is a *int: auto-ban stays on unless it is explicitly 0.
	ban := channel.AutoBan == nil || *channel.AutoBan != 0
	if channel.OpenAIOrganization != nil && *channel.OpenAIOrganization != "" {
		c.Set("channel_organization", *channel.OpenAIOrganization)
	}
	c.Set("auto_ban", ban)
	c.Set("model_mapping", channel.GetModelMapping())
	c.Set("status_code_mapping", channel.GetStatusCodeMapping())
	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
	c.Set("base_url", channel.GetBaseURL())

	// TODO: unify api_version handling.
	// channel.Other is exposed under a key that depends on the channel type.
	switch channel.Type {
	case common.ChannelTypeAzure,
		common.ChannelTypeXunfei,
		common.ChannelTypeGemini,
		common.ChannelCloudflare:
		c.Set("api_version", channel.Other)
	case common.ChannelTypeAli:
		c.Set("plugin", channel.Other)
	}
}