fix: 修复流模式错误扣费的问题 (close #95)

This commit is contained in:
1808837298@qq.com
2024-03-06 17:41:55 +08:00
parent 9de2d21e1a
commit 626217fbd4
16 changed files with 166 additions and 35 deletions

View File

@@ -71,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -236,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
}
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,

View File

@@ -173,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -209,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",

View File

@@ -10,17 +10,32 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -38,6 +53,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
//if a.RequestMode == RequestModeCompletion {
// return requestOpenAI2ClaudeComplete(*request), nil
//} else {
// return requestOpenAI2ClaudeMessage(*request), nil
//}
return request, nil
}

View File

@@ -24,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
@@ -44,7 +44,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
@@ -52,6 +54,10 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
return &claudeRequest
}
//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
//
//}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
@@ -167,7 +173,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",

View File

@@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}
if len(geminiResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",

View File

@@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body

View File

@@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",

View File

@@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
}
if TencentResponse.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},

View File

@@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if !zhipuResponse.Success {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",

View File

@@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body