From 626217fbd48fb7a4614d529da2af968adf89a470 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 6 Mar 2024 17:41:55 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B5=81=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E9=94=99=E8=AF=AF=E6=89=A3=E8=B4=B9=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20(close=20#95)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 7 ++-- controller/relay.go | 16 ++++----- dto/error.go | 45 ++++++++++++++++++++++-- relay/channel/ali/relay-ali.go | 4 +-- relay/channel/baidu/relay-baidu.go | 4 +-- relay/channel/claude/adaptor.go | 24 +++++++++++-- relay/channel/claude/relay-claude.go | 12 +++++-- relay/channel/gemini/relay-gemini.go | 2 +- relay/channel/openai/relay-openai.go | 4 +-- relay/channel/palm/relay-palm.go | 2 +- relay/channel/tencent/relay-tencent.go | 2 +- relay/channel/zhipu/relay-zhipu.go | 2 +- relay/channel/zhipu_4v/relay-zhipu_v4.go | 4 +-- relay/common/relay_utils.go | 8 ++--- relay/relay-text.go | 22 ++++++++++++ service/error.go | 43 ++++++++++++++++++++-- 16 files changed, 166 insertions(+), 35 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 2c24d2f..79929e6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -24,6 +24,9 @@ import ( ) func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { + if channel.Type == common.ChannelTypeMidjourney { + return errors.New("midjourney channel test is not supported"), nil + } common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -68,11 +71,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } if resp.StatusCode != http.StatusOK { err := relaycommon.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError + return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error } if usage == nil { return errors.New("usage is nil"), nil diff --git a/controller/relay.go b/controller/relay.go index a79e46c..f3d2772 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -38,24 +38,24 @@ func Relay(c *gin.Context) { retryTimes = common.RetryTimes } if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Error.Message)) } else { if err.StatusCode == http.StatusTooManyRequests { - //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" + //err.Error.Message = "当前分组上游负载已饱和,请稍后再试" } - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, + "error": err.Error, }) } channelId := c.GetInt("channel_id") autoBan := c.GetBool("auto_ban") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan { + if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") - service.DisableChannel(channelId, channelName, err.Message) + service.DisableChannel(channelId, channelName, err.Error.Message) } } } @@ -110,7 +110,7 @@ func RelayMidjourney(c *gin.Context) { } channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) - //if shouldDisableChannel(&err.OpenAIError) { + //if shouldDisableChannel(&err.Error) { // channelId := c.GetInt("channel_id") // channelName := c.GetString("channel_name") // disableChannel(channelId, channelName, err.Result) diff --git a/dto/error.go b/dto/error.go index bfb3376..e82e051 100644 --- a/dto/error.go +++ b/dto/error.go @@ -8,6 +8,47 @@ type OpenAIError struct { } type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` + Error OpenAIError `json:"error"` + StatusCode int `json:"status_code"` +} + +type GeneralErrorResponse struct { + Error OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" } diff --git a/relay/channel/ali/relay-ali.go b/relay/channel/ali/relay-ali.go index bc5395b..36e8d9e 100644 --- a/relay/channel/ali/relay-ali.go +++ b/relay/channel/ali/relay-ali.go @@ -71,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW if aliResponse.Code != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -236,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus } if aliResponse.Code != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 92bf043..6f773ba 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -173,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat } if baiduResponse.ErrorMsg != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -209,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro } if baiduResponse.ErrorMsg != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index a7245ee..2ed1e2e 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -10,17 +10,32 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "strings" +) + +const ( + RequestModeCompletion = 1 + RequestModeMessage = 2 ) type Adaptor struct { + RequestMode int } func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { - + if strings.HasPrefix(info.UpstreamModelName, "claude-3") { + a.RequestMode = RequestModeMessage + } else { + a.RequestMode = RequestModeCompletion + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + if a.RequestMode == RequestModeMessage { + return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil + } else { + return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { @@ -38,6 +53,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen if request == nil { return nil, errors.New("request is nil") } + //if a.RequestMode == RequestModeCompletion { + // return requestOpenAI2ClaudeComplete(*request), nil + //} else { + // return requestOpenAI2ClaudeMessage(*request), nil + //} return request, nil } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 186564f..1a285d5 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -24,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { +func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -44,7 +44,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) } else if message.Role == "system" { - prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + if prompt == "" { + prompt = message.StringContent() + } } } prompt += "\n\nAssistant:" @@ -52,6 +54,10 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { return &claudeRequest } +//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { +// +//} + func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = claudeResponse.Completion @@ -167,7 +173,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model } if claudeResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: claudeResponse.Error.Message, Type: claudeResponse.Error.Type, Param: "", diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 83ede7f..adf118a 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo } if len(geminiResponse.Candidates) == 0 { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: "No candidates returned", Type: "server_error", Param: "", diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b0f3aa5..9624606 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } if textResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, + Error: textResponse.Error, + StatusCode: resp.StatusCode, }, nil } // Reset response body diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 20706df..d775651 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: palmResponse.Error.Message, Type: palmResponse.Error.Status, Param: "", diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index b990c6f..6f4cd91 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt } if TencentResponse.Error.Code != 0 { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: TencentResponse.Error.Message, Code: TencentResponse.Error.Code, }, diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index d6d82f1..8a54842 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat } if !zhipuResponse.Success { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: zhipuResponse.Msg, Type: "zhipu_error", Param: "", diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index af9b1d8..34b4792 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat } if textResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, + Error: textResponse.Error, + StatusCode: resp.StatusCode, }, nil } // Reset response body diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 8e75d24..3ee2dfa 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -17,10 +17,10 @@ import ( var StopFinishReason = "stop" -func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { - openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ +func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { + OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, - OpenAIError: dto.OpenAIError{ + Error: dto.OpenAIError{ Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), Type: "upstream_error", Code: "bad_response_status_code", @@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.Open if err != nil { return } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error + OpenAIErrorWithStatusCode.Error = textResponse.Error return } diff --git a/relay/relay-text.go b/relay/relay-text.go index d38afaa..bd24642 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -2,6 +2,7 @@ package relay import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -148,10 +149,19 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + return service.RelayErrorHandler(resp) + } + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) return openaiErr } postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) @@ -218,6 +228,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return preConsumedQuota, userQuota, nil } +func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false) + if err != nil { + common.SysError("error return pre-consumed quota: " + err.Error()) + } + }(c) + } +} + func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens diff --git a/service/error.go b/service/error.go index 89d200c..303bcf7 100644 --- a/service/error.go +++ b/service/error.go @@ -1,9 +1,13 @@ package service import ( + "encoding/json" "fmt" + "io" + "net/http" "one-api/common" "one-api/dto" + "strconv" "strings" ) @@ -23,7 +27,42 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError Code: code, } return &dto.OpenAIErrorWithStatusCode{ - OpenAIError: openAIError, - StatusCode: statusCode, + Error: openAIError, + StatusCode: statusCode, } } + +func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { + errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: dto.OpenAIError{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse dto.GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + errWithStatusCode.Error = errResponse.Error + } else { + errWithStatusCode.Error.Message = errResponse.ToMessage() + } + if errWithStatusCode.Error.Message == "" { + errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +}