From 20da8228df25f330ecec7f4d6d7b0dcf38124d83 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 14:46:25 +0800 Subject: [PATCH] feat: update stream_options --- relay/channel/openai/relay-openai.go | 33 ++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 807f4b1..6058fd4 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -16,6 +16,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strings" + "sync" "time" ) @@ -41,7 +42,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel stopChan := make(chan bool) defer close(stopChan) - + var ( + lastStreamData string + mu sync.Mutex + ) gopool.Go(func() { for scanner.Scan() { info.SetFirstResponseTime() @@ -53,14 +57,19 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if data[:6] != "data: " && data[:6] != "[DONE]" { continue } + mu.Lock() data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - err := service.StringData(c, data) - if err != nil { - common.LogError(c, "streaming error: "+err.Error()) + if lastStreamData != "" { + err := service.StringData(c, lastStreamData) + if err != nil { + common.LogError(c, "streaming error: "+err.Error()) + } } + lastStreamData = data streamItems = append(streamItems, data) } + mu.Unlock() } common.SafeSendBool(stopChan, true) }) @@ -73,6 +82,22 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // 正常结束 } + shouldSendLastResp := true + var lastStreamResponse dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) + if err == nil { + if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) { + if info.ShouldIncludeUsage { + containStreamUsage = true + } else { + shouldSendLastResp = false + } + } + } + if shouldSendLastResp { + service.StringData(c, lastStreamData) + } + // 计算token streamResp := "[" + strings.Join(streamItems, ",") + "]" switch info.RelayMode {