From 402a415c7983a7f3123d80fffbe0b27746d3e013 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 27 Jun 2024 17:24:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E6=B5=81=E6=A8=A1=E5=BC=8F=E8=B6=85=E6=97=B6=E6=97=B6=E9=97=B4?= =?UTF-8?q?(gemini,=20claude)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/claude/relay-claude.go | 11 ++++++++--- relay/channel/gemini/relay-gemini.go | 11 ++++++++--- relay/channel/openai/relay-openai.go | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 3818af3..f34d9a7 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -267,8 +268,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) + dataChan := make(chan string, 5) + stopChan := make(chan bool, 2) go func() { for scanner.Scan() { data := scanner.Text() @@ -276,7 +277,11 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. continue } data = strings.TrimPrefix(data, "data: ") - dataChan <- data + if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { + // send data timeout, stop the stream + common.LogError(c, "send data timeout, stop the stream") + break + } } stopChan <- true }() diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 2ba5cef..8af08c5 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -163,8 +164,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" - dataChan := make(chan string) - stopChan := make(chan bool) + dataChan := make(chan string, 5) + stopChan := make(chan bool, 2) scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -187,7 +188,11 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom } data = strings.TrimPrefix(data, "\"text\": \"") data = strings.TrimSuffix(data, "\"") - dataChan <- data + if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { + // send data timeout, stop the stream + common.LogError(c, "send data timeout, stop the stream") + break + } } stopChan <- true }() diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 5146a4f..5c2acf4 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -54,7 +54,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { // send data timeout, stop the stream - common.LogInfo(c, "send data timeout, stop the stream") + common.LogError(c, "send data timeout, stop the stream") break } data = data[6:]