From e8b93ed6ec7f7e50fe3f638e6d45479e8e6303b8 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 8 Jul 2024 01:45:43 +0800 Subject: [PATCH] feat: support claude stream_options --- dto/text_response.go | 1 + relay/channel/claude/relay-claude.go | 21 ++++++++++++--------- relay/channel/openai/adaptor.go | 2 +- service/usage_helpr.go | 12 ++++++++++++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/dto/text_response.go b/dto/text_response.go index 9b45368..3310d02 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -102,6 +102,7 @@ type ChatCompletionsStreamResponse struct { Model string `json:"model"` SystemFingerprint *string `json:"system_fingerprint"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"` + Usage *Usage `json:"usage"` } type ChatCompletionsStreamResponseSimple struct { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index f34d9a7..35099ed 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -330,22 +330,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. response.Created = createdTime response.Model = info.UpstreamModelName - jsonStr, err := json.Marshal(response) + err = service.ObjectData(c, response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true + common.SysError(err.Error()) } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) return true case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { @@ -356,6 +349,16 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) } } + response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.SysError(err.Error()) + } + service.Done(c) + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } return nil, usage } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index ba4b66a..0c15793 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return nil, errors.New("request is nil") } // 如果不支持StreamOptions,将StreamOptions设置为nil - if !a.SupportStreamOptions { + if !a.SupportStreamOptions || !request.Stream { request.StreamOptions = nil } else { // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 15e3226..86fc456 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -24,3 +24,15 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err } + +func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse { + return &dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: createAt, + Model: model, + SystemFingerprint: nil, + Choices: nil, + Usage: &usage, + } +}