diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 29545d3..c0d3df1 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -9,10 +9,12 @@ import ( "net/http" "one-api/common" "strings" + "sync" + "time" ) func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { - responseText := "" + var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -26,9 +28,16 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) + dataChan := make(chan string, 5) + stopChan := make(chan bool, 2) + defer close(stopChan) + defer close(dataChan) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var streamItems []string for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format @@ -40,29 +49,39 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O continue } if data[:6] != "data: " && data[:6] != "[DONE]" { continue } dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } 
+ streamItems = append(streamItems, data) + } + } + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch relayMode { + case RelayModeChatCompletions: + var streamResponses []ChatCompletionsStreamResponseSimple + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.Content) } } + case RelayModeCompletions: + var streamResponses []CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + } + if len(dataChan) > 0 { + // wait data out + time.Sleep(2 * time.Second) } stopChan <- true }() @@ -85,7 +104,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - return nil, responseText + wg.Wait() + return nil, responseTextBuilder.String() } func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {