From f0907bf60ae4350ce404e16413f77dbe1e946356 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Mon, 7 Oct 2024 20:35:33 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=83=A8=E5=88=86=E6=83=85=E5=86=B5?= =?UTF-8?q?=E7=BC=BA=E5=B0=91=E8=BF=94=E5=9B=9E=E9=A2=84=E6=89=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e) --- relay/channel/openai/relay-openai.go | 8 ++++---- relay/relay-audio.go | 10 +++++++--- relay/relay-text.go | 12 +++++++----- relay/relay_rerank.go | 12 ++++++++---- relay/websocket.go | 9 +++++++-- 5 files changed, 33 insertions(+), 18 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index bb19684..2b237ea 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -391,7 +391,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage := &dto.RealtimeUsage{} sumUsage := &dto.RealtimeUsage{} - go func() { + gopool.Go(func() { for { select { case <-c.Done(): @@ -444,9 +444,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } } } - }() + }) - go func() { + gopool.Go(func() { for { select { case <-c.Done(): @@ -541,7 +541,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } } } - }() + }) select { case <-clientClosed: diff --git a/relay/relay-audio.go b/relay/relay-audio.go index b65f612..4bf916b 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return audioRequest, nil } -func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { +func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) @@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() // map model name modelMapping := c.GetString("model_mapping") @@ -128,8 +133,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/relay-text.go b/relay/relay-text.go index 7bb7d99..463947f 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) return textRequest, nil } -func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { +func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) @@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if openaiErr != nil { return openaiErr } - + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() includeUsage := false // 判断用户是否需要返回使用情况 if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { @@ -190,8 +194,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr @@ -200,7 +203,6 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 4cb1c98..a627b78 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) var rerankRequest *dto.RerankRequest @@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if openaiErr != nil { return openaiErr } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) @@ -104,8 +110,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr @@ -114,7 +119,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/websocket.go b/relay/websocket.go index 09d8298..247169e 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -30,7 +30,7 @@ import ( // return realtimeEvent, nil //} -func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode { +func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) // get & validate textRequest 获取并验证文本请求 @@ -96,6 +96,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod return openaiErr } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) @@ -118,7 +124,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr