diff --git a/controller/channel-test.go b/controller/channel-test.go index 6f82cd7..2174ff1 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -25,7 +25,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { +func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { tik := time.Now() if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil @@ -58,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error - return err, &openaiErr + return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[testModel] != "" { testModel = modelMap[testModel] @@ -104,11 +103,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } if resp != nil && resp.StatusCode != http.StatusOK { err := relaycommon.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + return fmt.Errorf("%s", respErr.Error.Message), respErr } if usage == nil { return errors.New("usage is nil"), nil @@ -222,7 +221,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel, "") + err, openaiWithStatusErr := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() @@ -233,14 +232,10 @@ func testAllChannels(notify bool) error { } // request error disables the channel - if openaiErr != nil { - err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) - openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ - StatusCode: -1, - Error: *openaiErr, - LocalError: false, - } - ban = service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) + if openaiWithStatusErr != nil { + oaiErr := openaiWithStatusErr.Error + err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) + ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) } // parse *int to bool @@ -254,7 +249,7 @@ func testAllChannels(notify bool) error { } // enable channel - if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { + if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { service.EnableChannel(channel.Id, channel.Name) } diff --git a/service/channel.go b/service/channel.go index 76be271..5716a6d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } if err != nil { return false } - if openAIErr != nil { + if openaiWithStatusErr != nil { return false } if status != common.ChannelStatusAutoDisabled {