From 7029065892142e8d240477dd888f297ba018ac94 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 18:04:05 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=B5=81?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/text_response.go | 15 +- relay/channel/openai/adaptor.go | 9 +- relay/channel/openai/relay-openai.go | 209 +++++++++++++-------------- service/usage_helpr.go | 4 + 4 files changed, 114 insertions(+), 123 deletions(-) diff --git a/dto/text_response.go b/dto/text_response.go index 3310d02..e1f0cc0 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool { - return c.Content == nil && len(c.ToolCalls) == 0 -} - func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { c.Content = &s } @@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { + if c.SystemFingerprint == nil { + return "" + } + return *c.SystemFingerprint +} + +func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) { + c.SystemFingerprint = &s +} + type ChatCompletionsStreamResponseSimple struct { Choices []ChatCompletionsStreamResponseChoice `json:"choices"` Usage *Usage `json:"usage"` diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index e327027..688dedc 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -14,7 +14,6 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" - "one-api/service" "strings" ) @@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - var toolCount int - err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - usage.CompletionTokens += toolCount * 7 - } + err, usage, _, _ = OpenaiStreamHandler(c, resp, info) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index dace39c..3fd7f03 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -14,38 +14,33 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strings" - "sync" "time" ) func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) { - //checkSensitive := constant.ShouldCheckCompletionSensitive() + hasStreamUsage := false + responseId := "" + var createAt int64 = 0 + var systemFingerprint string + var responseTextBuilder strings.Builder - var usage dto.Usage + var usage = &dto.Usage{} toolCount := 0 scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string, 5) + scanner.Split(bufio.ScanLines) + var streamItems []string // store stream items + + service.SetEventStreamHeaders(c) + + ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) + defer ticker.Stop() + stopChan := make(chan bool, 2) defer close(stopChan) - defer close(dataChan) - var wg sync.WaitGroup + go func() { - wg.Add(1) - defer wg.Done() - var streamItems []string // store stream items for scanner.Scan() { + ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format continue @@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. if data[:6] != "data: " && data[:6] != "[DONE]" { continue } - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } data = data[6:] if !strings.HasPrefix(data, "[DONE]") { + service.StringData(c, data) streamItems = append(streamItems, data) } } - // 计算token - streamResp := "[" + strings.Join(streamItems, ",") + "]" - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - var streamResponses []dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } - } - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - } else { - for _, streamResponse := range streamResponses { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } + stopChan <- true + }() + + select { + case <-ticker.C: + // 超时处理逻辑 + common.LogError(c, "streaming timeout") + case <-stopChan: + // 正常结束 + } + + // 计算token + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch info.RelayMode { + case relayconstant.RelayModeChatCompletions: + var streamResponses []dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + hasStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } } - case relayconstant.RelayModeCompletions: - var streamResponses []dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) + } else { + for _, streamResponse := range streamResponses { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + hasStreamUsage = true + } + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) } } } - } else { - for _, streamResponse := range streamResponses { + } + } + case relayconstant.RelayModeCompletions: + var streamResponses []dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Text) } } } - } - if len(dataChan) > 0 { - // wait data out - time.Sleep(2 * time.Second) - } - common.SafeSendBool(stopChan, true) - }() - service.SetEventStreamHeaders(c) - isFirst := true - ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) - defer ticker.Stop() - c.Stream(func(w io.Writer) bool { - select { - case <-ticker.C: - common.LogError(c, "reading data from upstream timeout") - return false - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() + } else { + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } } - ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false } - }) + } + + if !hasStreamUsage { + usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 + } + + if info.ShouldIncludeUsage && !hasStreamUsage { + response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage) + response.SetSystemFingerprint(systemFingerprint) + service.ObjectData(c, response) + } + + service.Done(c) + err := resp.Body.Close() if err != nil { return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount } - wg.Wait() - return nil, &usage, responseTextBuilder.String(), toolCount + return nil, usage, responseTextBuilder.String(), toolCount } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 528f3d4..adec566 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d Usage: &usage, } } + +func ValidUsage(usage *dto.Usage) bool { + return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0) +}