From 9896ba0a642f195f94af5dfd51af9edc84e10532 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 8 Jul 2024 01:52:40 +0800 Subject: [PATCH] feat: support aws stream_options --- relay/channel/aws/adaptor.go | 2 +- relay/channel/aws/relay-aws.go | 15 ++++++++++++--- relay/channel/claude/relay-claude.go | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 4de3a3a..6452392 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -68,7 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = awsStreamHandler(c, info, a.RequestMode) + err, usage = awsStreamHandler(c, resp, info, a.RequestMode) } else { err, usage = awsHandler(c, info, a.RequestMode) } diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 125f6f2..d549205 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -13,6 +13,7 @@ import ( relaymodel "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" + "one-api/service" "strings" "time" @@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return nil, &usage } -func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { +func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { awsCli, err := newAwsClient(c, info) if err != nil { return wrapErr(errors.Wrap(err, "newAwsClient")), nil @@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } @@ -214,6 +214,15 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i return false } }) - + response := service.GenerateFinalUsageResponse(id, createdTime, model, usage) + err = service.ObjectData(c, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) + } + service.Done(c) + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } return nil, &usage } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 35099ed..9090146 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -352,7 +352,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) err := service.ObjectData(c, response) if err != nil { - common.SysError(err.Error()) + common.SysError("send final response failed: " + err.Error()) } service.Done(c) err = resp.Body.Close()