From d1c8947851c001aad8d39a22db84b404be421c4d Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 25 Apr 2024 23:57:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=A7=84=E8=8C=83claude=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/text_response.go | 21 ++++++++++++++++++--- relay/channel/ali/relay-ali.go | 4 ++-- relay/channel/baidu/relay-baidu.go | 2 +- relay/channel/claude/relay-claude.go | 12 ++++++++---- relay/channel/cohere/relay-cohere.go | 2 +- relay/channel/gemini/relay-gemini.go | 4 ++-- relay/channel/openai/relay-openai.go | 4 ++-- relay/channel/palm/relay-palm.go | 2 +- relay/channel/tencent/relay-tencent.go | 4 ++-- relay/channel/xunfei/relay-xunfei.go | 2 +- relay/channel/zhipu/relay-zhipu.go | 4 ++-- 11 files changed, 40 insertions(+), 21 deletions(-) diff --git a/dto/text_response.go b/dto/text_response.go index a589d75..617c375 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -54,17 +54,32 @@ type OpenAIEmbeddingResponse struct { } type ChatCompletionsStreamResponseChoice struct { - Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` + Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason"` Index int `json:"index,omitempty"` } type ChatCompletionsStreamResponseChoiceDelta struct { - Content string `json:"content"` + Content *string `json:"content,omitempty"` Role string `json:"role,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } +func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool { + return c.Content == nil && len(c.ToolCalls) == 0 +} + +func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { + c.Content = &s +} + +func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string { + if c.Content == nil { + return "" + } + return *c.Content +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/relay/channel/ali/relay-ali.go b/relay/channel/ali/relay-ali.go index e087eea..4280b1c 100644 --- a/relay/channel/ali/relay-ali.go +++ b/relay/channel/ali/relay-ali.go @@ -136,7 +136,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = aliResponse.Output.Text + choice.Delta.SetContentString(aliResponse.Output.Text) if aliResponse.Output.FinishReason != "null" { finishReason := aliResponse.Output.FinishReason choice.FinishReason = &finishReason @@ -199,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } response := streamResponseAli2OpenAI(&aliResponse) - response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) + response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText)) lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 6f773ba..f1ceab3 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -57,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse { func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = baiduResponse.Result + choice.Delta.SetContentString(baiduResponse.Result) if baiduResponse.IsEnd { choice.FinishReason = &relaycommon.StopFinishReason } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 015b645..859b7b1 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -171,8 +171,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) var choice dto.ChatCompletionsStreamResponseChoice if reqMode == RequestModeCompletion { - choice.Delta.Content = claudeResponse.Completion - choice.Delta.Role = "assistant" + choice.Delta.SetContentString(claudeResponse.Completion) finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) if finishReason != "null" { choice.FinishReason = &finishReason @@ -182,10 +181,12 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Id = claudeResponse.Message.Id response.Model = claudeResponse.Message.Model claudeUsage = &claudeResponse.Message.Usage + } else if claudeResponse.Type == "content_block_start" { + choice.Delta.SetContentString("") + choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_delta" { choice.Index = claudeResponse.Index - choice.Delta.Content = claudeResponse.Delta.Text - choice.Delta.Role = "assistant" + choice.Delta.SetContentString(claudeResponse.Delta.Text) } else if claudeResponse.Type == "message_delta" { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) if finishReason != "null" { @@ -194,12 +195,15 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* claudeUsage = &claudeResponse.Usage } else if claudeResponse.Type == "message_stop" { return nil, nil + } else { + return nil, nil } } if claudeUsage == nil { claudeUsage = &ClaudeUsage{} } response.Choices = append(response.Choices, choice) + return &response, claudeUsage } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index a21d4a9..463e8b1 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -117,7 +117,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, { Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant", - Content: cohereResp.Text, + Content: &cohereResp.Text, }, Index: 0, }, diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 4a10a73..ee9301d 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -151,7 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = geminiResponse.GetResponseText() + choice.Delta.SetContentString(geminiResponse.GetResponseText()) choice.FinishReason = &relaycommon.StopFinishReason var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -203,7 +203,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr err := json.Unmarshal([]byte(data), &dummy) responseText += dummy.Content var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = dummy.Content + choice.Delta.SetContentString(dummy.Content) response := dto.ChatCompletionsStreamResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion.chunk", diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 5469ed7..d627575 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -68,7 +68,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) if err == nil { for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.Content) + responseTextBuilder.WriteString(choice.Delta.GetContentString()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > toolCount { toolCount = len(choice.Delta.ToolCalls) @@ -84,7 +84,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d } else { for _, streamResponse := range streamResponses { for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.Content) + responseTextBuilder.WriteString(choice.Delta.GetContentString()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > toolCount { toolCount = len(choice.Delta.ToolCalls) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 3a7d4fa..6933d6f 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -61,7 +61,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice if len(palmResponse.Candidates) > 0 { - choice.Delta.Content = palmResponse.Candidates[0].Content + choice.Delta.SetContentString(palmResponse.Candidates[0].Content) } choice.FinishReason = &relaycommon.StopFinishReason var response dto.ChatCompletionsStreamResponse diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 6f4cd91..c22b545 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -86,7 +86,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha } if len(TencentResponse.Choices) > 0 { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = TencentResponse.Choices[0].Delta.Content + choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content) if TencentResponse.Choices[0].FinishReason == "stop" { choice.FinishReason = &relaycommon.StopFinishReason } @@ -138,7 +138,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += response.Choices[0].Delta.GetContentString() } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 1690e96..7cb6c8a 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -87,7 +87,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo } } var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content) if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &relaycommon.StopFinishReason } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 8a54842..5ef9d7a 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -126,7 +126,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse { func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = zhipuResponse + choice.Delta.SetContentString(zhipuResponse) response := dto.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), @@ -138,7 +138,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = "" + choice.Delta.SetContentString("") choice.FinishReason = &relaycommon.StopFinishReason response := dto.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId,