diff --git a/controller/relay.go b/controller/relay.go index 0c79015..66339f4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -39,38 +39,28 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) - retryTimes := common.RetryTimes requestId := c.GetString(common.RequestIdKey) - channelId := c.GetInt("channel_id") - channelType := c.GetInt("channel_type") - channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString("original_model") - openaiErr := relayHandler(c, relayMode) - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - if openaiErr != nil { - go processChannelError(c, channelId, channelType, channelName, openaiErr) - } else { - retryTimes = 0 - } - for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ { - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + var openaiErr *dto.OpenAIErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + common.LogError(c, fmt.Sprintf("Failed to get channel: %s", err.Error())) break } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id)) - c.Set("use_channel", useChannel) - common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) - middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - openaiErr = relayHandler(c, relayMode) - if openaiErr != nil { - go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + openaiErr = relayRequest(c, relayMode, channel) + + if openaiErr == nil { + return // 成功处理请求,直接返回 + } + + go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break } } useChannel := c.GetStringSlice("use_channel") @@ -90,7 +80,36 @@ func Relay(c *gin.Context) { } } -func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { +func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relayHandler(c, relayMode) +} + +func addUsedChannel(c *gin.Context, channelId int) { + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) +} + +func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { + if retryCount == 0 { + return &model.Channel{ + Id: c.GetInt("channel_id"), + Type: c.GetInt("channel_type"), + Name: c.GetString("channel_name"), + }, nil + } + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) + if err != nil { + return nil, err + } + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + return channel, nil +} + +func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { if openaiErr == nil { return false } @@ -114,6 +133,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } if openaiErr.StatusCode == http.StatusBadRequest { + channelType := c.GetInt("channel_type") + if channelType == common.ChannelTypeAnthropic { + return true + } return false } if openaiErr.StatusCode == 408 { diff --git a/middleware/distributor.go b/middleware/distributor.go index f150b41..fad868d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -184,7 +184,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel == nil { return } - c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type)