🐛 fix: stream mode delay issue (#53)

This commit is contained in:
Buer
2024-01-25 11:56:31 +08:00
committed by GitHub
parent 705804e6dd
commit d7193b8e46
20 changed files with 291 additions and 262 deletions

View File

@@ -2,6 +2,7 @@ package baidu
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
@@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques
return p.convertToChatOpenai(baiduResponse, request)
}
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getBaiduChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque
}
// 转换为OpenAI聊天流式请求体
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal(*rawLine, &baiduResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&baiduResponse.BaiduError)
if error != nil {
errChan <- error
return
}
h.convertToOpenaiStream(&baiduResponse, dataChan, errChan)
if baiduResponse.IsEnd {
*isFinished = true
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
return h.convertToOpenaiStream(&baiduResponse, response)
}
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) {
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
@@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea
if baiduResponse.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
}
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
return nil
}