mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 21:33:41 +08:00 
			
		
		
		
	fix: 修复流模式错误扣费的问题 (close #95)
This commit is contained in:
		@@ -24,6 +24,9 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
 | 
			
		||||
	if channel.Type == common.ChannelTypeMidjourney {
 | 
			
		||||
		return errors.New("midjourney channel test is not supported"), nil
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 | 
			
		||||
	w := httptest.NewRecorder()
 | 
			
		||||
	c, _ := gin.CreateTestContext(w)
 | 
			
		||||
@@ -68,11 +71,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		err := relaycommon.RelayErrorHandler(resp)
 | 
			
		||||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
 | 
			
		||||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 | 
			
		||||
	}
 | 
			
		||||
	usage, respErr := adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
	if respErr != nil {
 | 
			
		||||
		return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
 | 
			
		||||
		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 | 
			
		||||
	}
 | 
			
		||||
	if usage == nil {
 | 
			
		||||
		return errors.New("usage is nil"), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -38,24 +38,24 @@ func Relay(c *gin.Context) {
 | 
			
		||||
			retryTimes = common.RetryTimes
 | 
			
		||||
		}
 | 
			
		||||
		if retryTimes > 0 {
 | 
			
		||||
			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
 | 
			
		||||
			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Error.Message))
 | 
			
		||||
		} else {
 | 
			
		||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
				//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
				//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
			}
 | 
			
		||||
			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
			
		||||
			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
 | 
			
		||||
			c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
				"error": err.OpenAIError,
 | 
			
		||||
				"error": err.Error,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		autoBan := c.GetBool("auto_ban")
 | 
			
		||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
 | 
			
		||||
		if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			service.DisableChannel(channelId, channelName, err.Message)
 | 
			
		||||
			service.DisableChannel(channelId, channelName, err.Error.Message)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -110,7 +110,7 @@ func RelayMidjourney(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
 | 
			
		||||
		//if shouldDisableChannel(&err.OpenAIError) {
 | 
			
		||||
		//if shouldDisableChannel(&err.Error) {
 | 
			
		||||
		//	channelId := c.GetInt("channel_id")
 | 
			
		||||
		//	channelName := c.GetString("channel_name")
 | 
			
		||||
		//	disableChannel(channelId, channelName, err.Result)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										45
									
								
								dto/error.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								dto/error.go
									
									
									
									
									
								
							@@ -8,6 +8,47 @@ type OpenAIError struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIErrorWithStatusCode struct {
 | 
			
		||||
	OpenAIError
 | 
			
		||||
	StatusCode int `json:"status_code"`
 | 
			
		||||
	Error      OpenAIError `json:"error"`
 | 
			
		||||
	StatusCode int         `json:"status_code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GeneralErrorResponse struct {
 | 
			
		||||
	Error    OpenAIError `json:"error"`
 | 
			
		||||
	Message  string      `json:"message"`
 | 
			
		||||
	Msg      string      `json:"msg"`
 | 
			
		||||
	Err      string      `json:"err"`
 | 
			
		||||
	ErrorMsg string      `json:"error_msg"`
 | 
			
		||||
	Header   struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Response struct {
 | 
			
		||||
		Error struct {
 | 
			
		||||
			Message string `json:"message"`
 | 
			
		||||
		} `json:"error"`
 | 
			
		||||
	} `json:"response"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e GeneralErrorResponse) ToMessage() string {
 | 
			
		||||
	if e.Error.Message != "" {
 | 
			
		||||
		return e.Error.Message
 | 
			
		||||
	}
 | 
			
		||||
	if e.Message != "" {
 | 
			
		||||
		return e.Message
 | 
			
		||||
	}
 | 
			
		||||
	if e.Msg != "" {
 | 
			
		||||
		return e.Msg
 | 
			
		||||
	}
 | 
			
		||||
	if e.Err != "" {
 | 
			
		||||
		return e.Err
 | 
			
		||||
	}
 | 
			
		||||
	if e.ErrorMsg != "" {
 | 
			
		||||
		return e.ErrorMsg
 | 
			
		||||
	}
 | 
			
		||||
	if e.Header.Message != "" {
 | 
			
		||||
		return e.Header.Message
 | 
			
		||||
	}
 | 
			
		||||
	if e.Response.Error.Message != "" {
 | 
			
		||||
		return e.Response.Error.Message
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -71,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 | 
			
		||||
 | 
			
		||||
	if aliResponse.Code != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: aliResponse.Message,
 | 
			
		||||
				Type:    aliResponse.Code,
 | 
			
		||||
				Param:   aliResponse.RequestId,
 | 
			
		||||
@@ -236,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
 | 
			
		||||
	}
 | 
			
		||||
	if aliResponse.Code != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: aliResponse.Message,
 | 
			
		||||
				Type:    aliResponse.Code,
 | 
			
		||||
				Param:   aliResponse.RequestId,
 | 
			
		||||
 
 | 
			
		||||
@@ -173,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 | 
			
		||||
	}
 | 
			
		||||
	if baiduResponse.ErrorMsg != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: baiduResponse.ErrorMsg,
 | 
			
		||||
				Type:    "baidu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
@@ -209,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
 | 
			
		||||
	}
 | 
			
		||||
	if baiduResponse.ErrorMsg != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: baiduResponse.ErrorMsg,
 | 
			
		||||
				Type:    "baidu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
 
 | 
			
		||||
@@ -10,17 +10,32 @@ import (
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RequestModeCompletion = 1
 | 
			
		||||
	RequestModeMessage    = 2
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
	RequestMode int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 | 
			
		||||
		a.RequestMode = RequestModeMessage
 | 
			
		||||
	} else {
 | 
			
		||||
		a.RequestMode = RequestModeCompletion
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
 | 
			
		||||
	if a.RequestMode == RequestModeMessage {
 | 
			
		||||
		return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
@@ -38,6 +53,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	//if a.RequestMode == RequestModeCompletion {
 | 
			
		||||
	//	return requestOpenAI2ClaudeComplete(*request), nil
 | 
			
		||||
	//} else {
 | 
			
		||||
	//	return requestOpenAI2ClaudeMessage(*request), nil
 | 
			
		||||
	//}
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -24,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
	claudeRequest := ClaudeRequest{
 | 
			
		||||
		Model:             textRequest.Model,
 | 
			
		||||
		Prompt:            "",
 | 
			
		||||
@@ -44,7 +44,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
		} else if message.Role == "assistant" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
			
		||||
		} else if message.Role == "system" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
 | 
			
		||||
			if prompt == "" {
 | 
			
		||||
				prompt = message.StringContent()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	prompt += "\n\nAssistant:"
 | 
			
		||||
@@ -52,6 +54,10 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
	return &claudeRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
 | 
			
		||||
//
 | 
			
		||||
//}
 | 
			
		||||
 | 
			
		||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice dto.ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = claudeResponse.Completion
 | 
			
		||||
@@ -167,7 +173,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
			
		||||
	}
 | 
			
		||||
	if claudeResponse.Error.Type != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: claudeResponse.Error.Message,
 | 
			
		||||
				Type:    claudeResponse.Error.Type,
 | 
			
		||||
				Param:   "",
 | 
			
		||||
 
 | 
			
		||||
@@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 | 
			
		||||
	}
 | 
			
		||||
	if len(geminiResponse.Candidates) == 0 {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: "No candidates returned",
 | 
			
		||||
				Type:    "server_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
 
 | 
			
		||||
@@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
			
		||||
	}
 | 
			
		||||
	if textResponse.Error.Type != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: textResponse.Error,
 | 
			
		||||
			StatusCode:  resp.StatusCode,
 | 
			
		||||
			Error:      textResponse.Error,
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	// Reset response body
 | 
			
		||||
 
 | 
			
		||||
@@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 | 
			
		||||
	}
 | 
			
		||||
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: palmResponse.Error.Message,
 | 
			
		||||
				Type:    palmResponse.Error.Status,
 | 
			
		||||
				Param:   "",
 | 
			
		||||
 
 | 
			
		||||
@@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
 | 
			
		||||
	}
 | 
			
		||||
	if TencentResponse.Error.Code != 0 {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: TencentResponse.Error.Message,
 | 
			
		||||
				Code:    TencentResponse.Error.Code,
 | 
			
		||||
			},
 | 
			
		||||
 
 | 
			
		||||
@@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 | 
			
		||||
	}
 | 
			
		||||
	if !zhipuResponse.Success {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: dto.OpenAIError{
 | 
			
		||||
			Error: dto.OpenAIError{
 | 
			
		||||
				Message: zhipuResponse.Msg,
 | 
			
		||||
				Type:    "zhipu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
 
 | 
			
		||||
@@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 | 
			
		||||
	}
 | 
			
		||||
	if textResponse.Error.Type != "" {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: textResponse.Error,
 | 
			
		||||
			StatusCode:  resp.StatusCode,
 | 
			
		||||
			Error:      textResponse.Error,
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	// Reset response body
 | 
			
		||||
 
 | 
			
		||||
@@ -17,10 +17,10 @@ import (
 | 
			
		||||
 | 
			
		||||
var StopFinishReason = "stop"
 | 
			
		||||
 | 
			
		||||
func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
		StatusCode: resp.StatusCode,
 | 
			
		||||
		OpenAIError: dto.OpenAIError{
 | 
			
		||||
		Error: dto.OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
			Code:    "bad_response_status_code",
 | 
			
		||||
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.Open
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 | 
			
		||||
	OpenAIErrorWithStatusCode.Error = textResponse.Error
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -148,10 +149,19 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
		return service.RelayErrorHandler(resp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
 | 
			
		||||
@@ -218,6 +228,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 | 
			
		||||
	return preConsumedQuota, userQuota, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
 | 
			
		||||
	if preConsumedQuota != 0 {
 | 
			
		||||
		go func(ctx context.Context) {
 | 
			
		||||
			// return pre-consumed quota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error return pre-consumed quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
 | 
			
		||||
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 | 
			
		||||
	promptTokens := usage.PromptTokens
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,13 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -23,7 +27,42 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
 | 
			
		||||
		Code:    code,
 | 
			
		||||
	}
 | 
			
		||||
	return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
		OpenAIError: openAIError,
 | 
			
		||||
		StatusCode:  statusCode,
 | 
			
		||||
		Error:      openAIError,
 | 
			
		||||
		StatusCode: statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
		StatusCode: resp.StatusCode,
 | 
			
		||||
		Error: dto.OpenAIError{
 | 
			
		||||
			Message: "",
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
			Code:    "bad_response_status_code",
 | 
			
		||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var errResponse dto.GeneralErrorResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &errResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if errResponse.Error.Message != "" {
 | 
			
		||||
		// OpenAI format error, so we override the default one
 | 
			
		||||
		errWithStatusCode.Error = errResponse.Error
 | 
			
		||||
	} else {
 | 
			
		||||
		errWithStatusCode.Error.Message = errResponse.ToMessage()
 | 
			
		||||
	}
 | 
			
		||||
	if errWithStatusCode.Error.Message == "" {
 | 
			
		||||
		errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user