refactor: 重构流模式逻辑

This commit is contained in:
CalciumIon 2024-07-15 18:04:05 +08:00
parent 0f687aab9a
commit 7029065892
4 changed files with 114 additions and 123 deletions

View File

@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
// IsEmpty reports whether the delta carries neither text content nor any
// tool calls, i.e. there is nothing to forward to the client.
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
	if c.Content != nil {
		return false
	}
	return len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s c.Content = &s
} }
@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
} }
// GetSystemFingerprint returns the response's system fingerprint, or the
// empty string when the upstream response did not include one.
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
	if fp := c.SystemFingerprint; fp != nil {
		return *fp
	}
	return ""
}
// SetSystemFingerprint stores s as the response's system fingerprint,
// taking the address of a local copy so the field points at fresh storage.
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
	fp := s
	c.SystemFingerprint = &fp
}
type ChatCompletionsStreamResponseSimple struct { type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`

View File

@ -14,7 +14,6 @@ import (
"one-api/relay/channel/minimax" "one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"strings" "strings"
) )
@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
var toolCount int
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else { } else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }

View File

@ -14,38 +14,33 @@ import (
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings" "strings"
"sync"
"time" "time"
) )
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) { func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive() hasStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
var usage dto.Usage var usage = &dto.Usage{}
toolCount := 0 toolCount := 0
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 { var streamItems []string // store stream items
return 0, nil, nil
} service.SetEventStreamHeaders(c)
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
} defer ticker.Stop()
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2) stopChan := make(chan bool, 2)
defer close(stopChan) defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() { go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string // store stream items
for scanner.Scan() { for scanner.Scan() {
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text() data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format if len(data) < 6 { // ignore blank line or wrong format
continue continue
@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" { if data[:6] != "data: " && data[:6] != "[DONE]" {
continue continue
} }
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
service.StringData(c, data)
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
} }
} }
// 计算token stopChan <- true
streamResp := "[" + strings.Join(streamItems, ",") + "]" }()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions: select {
var streamResponses []dto.ChatCompletionsStreamResponseSimple case <-ticker.C:
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) // 超时处理逻辑
if err != nil { common.LogError(c, "streaming timeout")
// 一次性解析失败,逐个解析 case <-stopChan:
common.SysError("error unmarshalling stream response: " + err.Error()) // 正常结束
for _, item := range streamItems { }
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) // 计算token
if err == nil { streamResp := "[" + strings.Join(streamItems, ",") + "]"
if streamResponse.Usage != nil { switch info.RelayMode {
if streamResponse.Usage.TotalTokens != 0 { case relayconstant.RelayModeChatCompletions:
usage = *streamResponse.Usage var streamResponses []dto.ChatCompletionsStreamResponse
} err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
} if err != nil {
for _, choice := range streamResponse.Choices { // 一次性解析失败,逐个解析
responseTextBuilder.WriteString(choice.Delta.GetContentString()) common.SysError("error unmarshalling stream response: " + err.Error())
if choice.Delta.ToolCalls != nil { for _, item := range streamItems {
if len(choice.Delta.ToolCalls) > toolCount { var streamResponse dto.ChatCompletionsStreamResponse
toolCount = len(choice.Delta.ToolCalls) err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
} if err == nil {
for _, tool := range choice.Delta.ToolCalls { responseId = streamResponse.Id
responseTextBuilder.WriteString(tool.Function.Name) createAt = streamResponse.Created
responseTextBuilder.WriteString(tool.Function.Arguments) systemFingerprint = streamResponse.GetSystemFingerprint()
} if service.ValidUsage(streamResponse.Usage) {
} usage = streamResponse.Usage
} hasStreamUsage = true
}
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
} }
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetContentString())
@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
} }
} }
case relayconstant.RelayModeCompletions: } else {
var streamResponses []dto.CompletionsStreamResponse for _, streamResponse := range streamResponses {
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) responseId = streamResponse.Id
if err != nil { createAt = streamResponse.Created
// 一次性解析失败,逐个解析 systemFingerprint = streamResponse.GetSystemFingerprint()
common.SysError("error unmarshalling stream response: " + err.Error()) if service.ValidUsage(streamResponse.Usage) {
for _, item := range streamItems { usage = streamResponse.Usage
var streamResponse dto.CompletionsStreamResponse hasStreamUsage = true
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) }
if err == nil { for _, choice := range streamResponse.Choices {
for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Text) if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
} }
} }
} }
} else { }
for _, streamResponse := range streamResponses { }
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text) responseTextBuilder.WriteString(choice.Text)
} }
} }
} }
} } else {
if len(dataChan) > 0 { for _, streamResponse := range streamResponses {
// wait data out for _, choice := range streamResponse.Choices {
time.Sleep(2 * time.Second) responseTextBuilder.WriteString(choice.Text)
} }
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
} }
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
} }
}) }
if !hasStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
if info.ShouldIncludeUsage && !hasStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
service.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
} }
wg.Wait() return nil, usage, responseTextBuilder.String(), toolCount
return nil, &usage, responseTextBuilder.String(), toolCount
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
Usage: &usage, Usage: &usage,
} }
} }
// ValidUsage reports whether usage is non-nil and records at least one
// prompt or completion token. NOTE(review): TotalTokens is deliberately
// not consulted here — confirm that matches callers' expectations.
func ValidUsage(usage *dto.Usage) bool {
	if usage == nil {
		return false
	}
	return usage.PromptTokens != 0 || usage.CompletionTokens != 0
}