refactor: Refactor: Improve Anthropic event stream response handling

- Move `scanner.Split` instantiation to a new function
- Introduce a new regular expression to extract data from the response
- Utilize regular expressions to pre-process the event stream
This commit is contained in:
Laisky.Cai
2024-03-05 08:17:14 +00:00
parent fdde066252
commit bcd5cf3d5f

View File

@@ -4,15 +4,17 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"regexp"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@@ -84,34 +86,40 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
return &fullTextResponse return &fullTextResponse
} }
var dataRegexp = regexp.MustCompile(`^data: (\{.*\})\B`)
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { // if atEOF && len(data) == 0 {
return 0, nil, nil // return 0, nil, nil
} // }
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { // if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil // return i + 4, data[0:i], nil
} // }
if atEOF { // if atEOF {
return len(data), data, nil // return len(data), data, nil
} // }
return 0, nil, nil // return 0, nil, nil
}) // })
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := strings.TrimSpace(scanner.Text())
if !strings.HasPrefix(data, "event: completion") { // logger.SysLog(fmt.Sprintf("stream response: %s", data))
continue
matched := dataRegexp.FindAllStringSubmatch(data, -1)
for _, match := range matched {
data = match[1]
// logger.SysLog(fmt.Sprintf("chunk response: %s", data))
dataChan <- data
} }
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
} }
stopChan <- true stopChan <- true
}() }()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)