refactor: 重构流模式逻辑

This commit is contained in:
CalciumIon 2024-07-15 18:04:05 +08:00
parent 0f687aab9a
commit 7029065892
4 changed files with 114 additions and 123 deletions

View File

@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
return c.Content == nil && len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`

View File

@ -14,7 +14,6 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@ -14,38 +14,33 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
hasStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
var responseTextBuilder strings.Builder
var usage dto.Usage
var usage = &dto.Usage{}
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
scanner.Split(bufio.ScanLines)
var streamItems []string // store stream items
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
stopChan := make(chan bool, 2)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string // store stream items
for scanner.Scan() {
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
@ -53,33 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
service.StringData(c, data)
streamItems = append(streamItems, data)
}
}
stopChan <- true
}()
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@ -97,10 +101,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage = *streamResponse.Usage
}
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@ -139,44 +145,25 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
if !hasStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
case <-ticker.C:
common.LogError(c, "reading data from upstream timeout")
return false
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
if info.ShouldIncludeUsage && !hasStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
wg.Wait()
return nil, &usage, responseTextBuilder.String(), toolCount
return nil, usage, responseTextBuilder.String(), toolCount
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
}