From 4b60528c5fbf4f6e3907d85ccc2098c16390d9be Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 4 Apr 2024 16:35:44 +0800
Subject: [PATCH] feat: local retry (本地重试)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/gin.go              |  27 +++++++--
 controller/relay.go        | 111 ++++++++++++++++++++++++++++---------
 dto/error.go               |   1 +
 middleware/auth.go         |   2 +-
 middleware/distributor.go  |  65 ++++++++++++----------
 model/ability.go           |  23 +++-----
 model/cache.go             |  42 ++++++++------
 relay/common/relay_info.go |   1 +
 relay/relay-text.go        |  16 +++---
 service/channel.go         |  24 +++++++-
 service/error.go           |   6 ++
 11 files changed, 215 insertions(+), 103 deletions(-)

diff --git a/common/gin.go b/common/gin.go
index ffa1e21..4a909df 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -5,18 +5,37 @@ import (
 	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"io"
+	"strings"
 )
 
-func UnmarshalBodyReusable(c *gin.Context, v any) error {
+const KeyRequestBody = "key_request_body"
+
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+	requestBody, _ := c.Get(KeyRequestBody)
+	if requestBody != nil {
+		return requestBody.([]byte), nil
+	}
 	requestBody, err := io.ReadAll(c.Request.Body)
 	if err != nil {
-		return err
+		return nil, err
 	}
-	err = c.Request.Body.Close()
+	_ = c.Request.Body.Close()
+	c.Set(KeyRequestBody, requestBody)
+	return requestBody.([]byte), nil
+}
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+	requestBody, err := GetRequestBody(c)
 	if err != nil {
 		return err
 	}
-	err = json.Unmarshal(requestBody, &v)
+	contentType := c.Request.Header.Get("Content-Type")
+	if strings.HasPrefix(contentType, "application/json") {
+		err = json.Unmarshal(requestBody, &v)
+	} else {
+		// skip for now
+		// TODO: implement this once non-JSON requests can carry model variants
+	}
 	if err != nil {
 		return err
 	}
diff --git a/controller/relay.go b/controller/relay.go
index 9f866b8..5c89f91 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,21 +1,23 @@
 package controller
 
 import (
+	"bytes"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"io"
 	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
+	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
-	"strconv"
 )
 
-func Relay(c *gin.Context) {
-	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
 	case relayconstant.RelayModeImagesGenerations:
@@ -29,33 +31,88 @@ func Relay(c *gin.Context) {
 	default:
 		err = relay.TextHelper(c)
 	}
-	if err != nil {
-		requestId := c.GetString(common.RequestIdKey)
-		retryTimesStr := c.Query("retry")
-		retryTimes, _ := strconv.Atoi(retryTimesStr)
-		if retryTimesStr == "" {
-			retryTimes = common.RetryTimes
+	return err
+}
+
+func Relay(c *gin.Context) {
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	retryTimes := common.RetryTimes
+	requestId := c.GetString(common.RequestIdKey)
+	channelId := c.GetInt("channel_id")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	openaiErr := relayHandler(c, relayMode)
+	retryLogStr := fmt.Sprintf("重试:%d", channelId)
+	if openaiErr != nil {
+		go processChannelError(c, channelId, openaiErr)
+	} else {
+		retryTimes = 0
+	}
+	for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
+		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+		if err != nil {
+			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			break
 		}
-		if retryTimes > 0 {
-			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
-		} else {
-			if err.StatusCode == http.StatusTooManyRequests {
-				//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
-			}
-			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
-			c.JSON(err.StatusCode, gin.H{
-				"error": err.Error,
-			})
+		channelId = channel.Id
+		retryLogStr += fmt.Sprintf("->%d", channel.Id)
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+		requestBody, err := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		openaiErr = relayHandler(c, relayMode)
+		if openaiErr != nil {
+			go processChannelError(c, channelId, openaiErr)
 		}
-		channelId := c.GetInt("channel_id")
-		autoBan := c.GetBool("auto_ban")
-		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
-		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
-			channelId := c.GetInt("channel_id")
-			channelName := c.GetString("channel_name")
-			service.DisableChannel(channelId, channelName, err.Error.Message)
+	}
+	common.LogInfo(c.Request.Context(), retryLogStr)
+
+	if openaiErr != nil {
+		if openaiErr.StatusCode == http.StatusTooManyRequests {
+			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 		}
+		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+		c.JSON(openaiErr.StatusCode, gin.H{
+			"error": openaiErr.Error,
+		})
+	}
+}
+
+func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+	if openaiErr == nil {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if openaiErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if openaiErr.StatusCode/100 == 5 {
+		return true
+	}
+	if openaiErr.StatusCode == http.StatusBadRequest {
+		return false
+	}
+	if openaiErr.LocalError {
+		return false
+	}
+	if openaiErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}
+
+func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+	autoBan := c.GetBool("auto_ban")
+	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
+	if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
+		channelName := c.GetString("channel_name")
+		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
 }
 
diff --git a/dto/error.go b/dto/error.go
index e82e051..b347f6a 100644
--- a/dto/error.go
+++ b/dto/error.go
@@ -10,6 +10,7 @@ type OpenAIError struct {
 type OpenAIErrorWithStatusCode struct {
 	Error      OpenAIError `json:"error"`
 	StatusCode int         `json:"status_code"`
+	LocalError bool
 }
 
 type GeneralErrorResponse struct {
diff --git a/middleware/auth.go b/middleware/auth.go
index 4b865c2..686f2d9 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -127,7 +127,7 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
-				c.Set("channelId", parts[1])
c.Set("specific_channel_id", parts[1]) } else { abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return diff --git a/middleware/distributor.go b/middleware/distributor.go index 10696a9..4db683f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) { return func(c *gin.Context) { userId := c.GetInt("id") var channel *model.Channel - channelId, ok := c.Get("channelId") + channelId, ok := c.Get("specific_channel_id") if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) { userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) if shouldSelectChannel { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) // 如果错误,但是渠道不为空,说明是数据库一致性问题 @@ -147,36 +147,41 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - if nil != channel.OpenAIOrganization { - c.Set("channel_organization", *channel.OpenAIOrganization) - } - c.Set("auto_ban", ban) - c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - // TODO: api_version统一 - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } + SetupContextForSelectedChannel(c, channel, modelRequest.Model) } } c.Next() } } + +func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + if nil != channel.OpenAIOrganization { + c.Set("channel_organization", *channel.OpenAIOrganization) + } + c.Set("auto_ban", ban) + c.Set("model_mapping", channel.GetModelMapping()) + c.Set("original_model", modelName) // for retry + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 + switch channel.Type { + case common.ChannelTypeAzure: + c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) + } +} diff --git a/model/ability.go b/model/ability.go index b79978d..285ce15 100644 --- a/model/ability.go +++ b/model/ability.go @@ -52,21 +52,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, 
 		// Randomly choose one
 		weightSum := uint(0)
 		for _, ability_ := range abilities {
-			weightSum += ability_.Weight
+			weightSum += ability_.Weight + 10
 		}
-		if weightSum == 0 {
-			// All weight is 0, randomly choose one
-			channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
-		} else {
-			// Randomly choose one
-			weight := common.GetRandomInt(int(weightSum))
-			for _, ability_ := range abilities {
-				weight -= int(ability_.Weight)
-				//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
-				if weight <= 0 {
-					channel.Id = ability_.ChannelId
-					break
-				}
+		// Randomly choose one
+		weight := common.GetRandomInt(int(weightSum))
+		for _, ability_ := range abilities {
+			weight -= int(ability_.Weight)
+			//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
+			if weight <= 0 {
+				channel.Id = ability_.ChannelId
+				break
 			}
 		}
 	} else {
diff --git a/model/cache.go b/model/cache.go
index 8294e73..78bdc17 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -265,7 +265,7 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
 	if strings.HasPrefix(model, "gpt-4-gizmo") {
 		model = "gpt-4-gizmo-*"
 	}
@@ -280,15 +280,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
-	endIdx := len(channels)
-	// choose by priority
-	firstChannel := channels[0]
-	if firstChannel.GetPriority() > 0 {
-		for i := range channels {
-			if channels[i].GetPriority() != firstChannel.GetPriority() {
-				endIdx = i
-				break
-			}
+
+	uniquePriorities := make(map[int]bool)
+	for _, channel := range channels {
+		uniquePriorities[int(channel.GetPriority())] = true
+	}
+	var sortedUniquePriorities []int
+	for priority := range uniquePriorities {
+		sortedUniquePriorities = append(sortedUniquePriorities, priority)
+	}
+	sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
+
+	if retry >= len(uniquePriorities) {
+		retry = len(uniquePriorities) - 1
+	}
+	targetPriority := int64(sortedUniquePriorities[retry])
+
+	// get the priority for the given retry number
+	var targetChannels []*Channel
+	for _, channel := range channels {
+		if channel.GetPriority() == targetPriority {
+			targetChannels = append(targetChannels, channel)
 		}
 	}
 
@@ -296,20 +308,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	smoothingFactor := 10
 	// Calculate the total weight of all channels up to endIdx
 	totalWeight := 0
-	for _, channel := range channels[:endIdx] {
+	for _, channel := range targetChannels {
 		totalWeight += channel.GetWeight() + smoothingFactor
 	}
-
-	//if totalWeight == 0 {
-	//	// If all weights are 0, select a channel randomly
-	//	return channels[rand.Intn(endIdx)], nil
-	//}
-
 	// Generate a random value in the range [0, totalWeight)
 	randomWeight := rand.Intn(totalWeight)
 
 	// Find a channel based on its weight
-	for _, channel := range channels[:endIdx] {
+	for _, channel := range targetChannels {
 		randomWeight -= channel.GetWeight() + smoothingFactor
 		if randomWeight < 0 {
 			return channel, nil
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 27ed9a9..7ae9dd4 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -31,6 +31,7 @@ type RelayInfo struct {
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
c.GetInt("channel_id") + tokenId := c.GetInt("token_id") userId := c.GetInt("id") group := c.GetString("group") diff --git a/relay/relay-text.go b/relay/relay-text.go index ff653ff..879f40a 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { textRequest, err := getAndValidateTextRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } // map model name @@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[textRequest.Model] != "" { textRequest.Model = modelMap[textRequest.Model] @@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // count messages token error 计算promptTokens错误 if err != nil { if sensitiveTrigger { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) } return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } @@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode) + return service.RelayErrorHandler(resp) } usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) @@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { - return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota <= 0 || userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 @@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) if err != nil { - 
return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } return preConsumedQuota, userQuota, nil diff --git a/service/channel.go b/service/channel.go index b9a7627..6ce444d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -6,6 +6,7 @@ import ( "one-api/common" relaymodel "one-api/dto" "one-api/model" + "strings" ) // disable & notify @@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { if statusCode == http.StatusUnauthorized { return true } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { + switch err.Code { + case "invalid_api_key": + return true + case "account_deactivated": + return true + case "billing_not_active": + return true + } + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true } return false diff --git a/service/error.go b/service/error.go index cda26b3..39eb0f9 100644 --- a/service/error.go +++ b/service/error.go @@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError } } +func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { + openaiErr := OpenAIErrorWrapper(err, code, statusCode) + openaiErr.LocalError = true + return openaiErr +} + func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode,