fix: 部分情况缺少返回预扣

(cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e)
This commit is contained in:
Xyfacai 2024-10-07 20:35:33 +08:00 committed by CalciumIon
parent e5c05d77b7
commit f0907bf60a
5 changed files with 33 additions and 18 deletions

View File

@ -391,7 +391,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage := &dto.RealtimeUsage{} localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{} sumUsage := &dto.RealtimeUsage{}
go func() { gopool.Go(func() {
for { for {
select { select {
case <-c.Done(): case <-c.Done():
@ -444,9 +444,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
} }
} }
} }
}() })
go func() { gopool.Go(func() {
for { for {
select { select {
case <-c.Done(): case <-c.Done():
@ -541,7 +541,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
} }
} }
} }
}() })
select { select {
case <-clientClosed: case <-clientClosed:

View File

@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return audioRequest, nil return audioRequest, nil
} }
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo) audioRequest, err := getAndValidAudioRequest(c, relayInfo)
@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
@ -128,8 +133,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) openaiErr = service.RelayErrorHandler(httpResp)
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
return textRequest, nil return textRequest, nil
} }
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
includeUsage := false includeUsage := false
// 判断用户是否需要返回使用情况 // 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
@ -190,8 +194,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) openaiErr = service.RelayErrorHandler(httpResp)
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
@ -200,7 +203,6 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token return token
} }
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest var rerankRequest *dto.RerankRequest
@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if openaiErr != nil { if openaiErr != nil {
return openaiErr return openaiErr
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@ -104,8 +110,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) openaiErr = service.RelayErrorHandler(httpResp)
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
@ -114,7 +119,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@ -30,7 +30,7 @@ import (
// return realtimeEvent, nil // return realtimeEvent, nil
//} //}
func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode { func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws) relayInfo := relaycommon.GenRelayInfoWs(c, ws)
// get & validate textRequest 获取并验证文本请求 // get & validate textRequest 获取并验证文本请求
@ -96,6 +96,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
return openaiErr return openaiErr
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@ -118,7 +124,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr