diff --git a/controller/channel-test.go b/controller/channel-test.go index 37c1876..ea82578 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -228,7 +228,7 @@ func testAllChannels(notify bool) error { Error: *openaiErr, LocalError: false, } - if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban { + if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban { service.DisableChannel(channel.Id, channel.Name, err.Error()) } if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { diff --git a/controller/relay.go b/controller/relay.go index e329aed..03853c1 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -40,12 +40,13 @@ func Relay(c *gin.Context) { retryTimes := common.RetryTimes requestId := c.GetString(common.RequestIdKey) channelId := c.GetInt("channel_id") + channelType := c.GetInt("channel_type") group := c.GetString("group") originalModel := c.GetString("original_model") openaiErr := relayHandler(c, relayMode) c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) if openaiErr != nil { - go processChannelError(c, channelId, openaiErr) + go processChannelError(c, channelId, channelType, openaiErr) } else { retryTimes = 0 } @@ -66,7 +67,7 @@ func Relay(c *gin.Context) { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) openaiErr = relayHandler(c, relayMode) if openaiErr != nil { - go processChannelError(c, channelId, openaiErr) + go processChannelError(c, channelId, channel.Type, openaiErr) } } useChannel := c.GetStringSlice("use_channel") @@ -125,10 +126,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } -func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) { +func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) { autoBan := c.GetBool("auto_ban") common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) - if service.ShouldDisableChannel(err) && autoBan { + if service.ShouldDisableChannel(channelType, err) && autoBan { channelName := c.GetString("channel_name") service.DisableChannel(channelId, channelName, err.Error.Message) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 4862a48..61361e6 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -178,6 +178,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) + c.Set("channel_type", channel.Type) ban := true // parse *int to bool if channel.AutoBan != nil && *channel.AutoBan == 0 { diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 09b67cf..9137721 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -73,14 +73,14 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { - return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota @@ -90,7 +90,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if preConsumedQuota > 0 { userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } diff --git a/relay/relay-image.go b/relay/relay-image.go index cbd226a..d83ec26 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -147,7 +147,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N if userQuota-quota < 0 { - return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) diff --git a/service/channel.go b/service/channel.go index 628605b..76be271 100644 --- a/service/channel.go +++ b/service/channel.go @@ -24,7 +24,7 @@ func EnableChannel(channelId int, channelName string) { notifyRootUser(subject, content) } -func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool { +func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool { if !common.AutomaticDisableChannelEnabled { return false } @@ -34,9 +34,15 @@ func ShouldDisableChannel(err *relaymodel.OpenAIErrorWithStatusCode) bool { if err.LocalError { return false } - if err.StatusCode == http.StatusUnauthorized || err.StatusCode == http.StatusForbidden { + if err.StatusCode == http.StatusUnauthorized { return true } + if err.StatusCode == http.StatusForbidden { + switch channelType { + case common.ChannelTypeGemini: + return true + } + } switch err.Error.Code { case "invalid_api_key": return true