Merge pull request #20 from Calcium-Ion/optimize/hign--cpu

fix: 修复客户端中断请求,计算补全阻塞问题
This commit is contained in:
Calcium-Ion 2023-12-07 17:11:57 +08:00 committed by GitHub
commit e095900d88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,10 +9,12 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"strings" "strings"
"sync"
"time"
) )
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
responseText := "" var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@ -26,9 +28,16 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string) dataChan := make(chan string, 5)
stopChan := make(chan bool) stopChan := make(chan bool, 2)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() { go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format if len(data) < 6 { // ignore blank line or wrong format
@ -40,29 +49,39 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
dataChan <- data dataChan <- data
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
switch relayMode { streamItems = append(streamItems, data)
case RelayModeChatCompletions: }
var streamResponse ChatCompletionsStreamResponseSimple }
err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse) streamResp := "[" + strings.Join(streamItems, ",") + "]"
if err != nil { switch relayMode {
common.SysError("error unmarshalling stream response: " + err.Error()) case RelayModeChatCompletions:
continue // just ignore the error var streamResponses []ChatCompletionsStreamResponseSimple
} err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
for _, choice := range streamResponse.Choices { if err != nil {
responseText += choice.Delta.Content common.SysError("error unmarshalling stream response: " + err.Error())
} return // just ignore the error
case RelayModeCompletions: }
var streamResponse CompletionsStreamResponse for _, streamResponse := range streamResponses {
err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse) for _, choice := range streamResponse.Choices {
if err != nil { responseTextBuilder.WriteString(choice.Delta.Content)
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
} }
} }
case RelayModeCompletions:
var streamResponses []CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
} }
stopChan <- true stopChan <- true
}() }()
@ -85,7 +104,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText wg.Wait()
return nil, responseTextBuilder.String()
} }
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {