fix: openai stream response

This commit is contained in:
CalciumIon
2024-07-15 19:06:13 +08:00
parent 220ab412e2
commit e2b9061650

View File

@@ -22,20 +22,22 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId := "" responseId := ""
var createAt int64 = 0 var createAt int64 = 0
var systemFingerprint string var systemFingerprint string
model := info.UpstreamModelName
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
var usage = &dto.Usage{} var usage = &dto.Usage{}
var streamItems []string // store stream items
toolCount := 0 toolCount := 0
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
var streamItems []string // store stream items
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop() defer ticker.Stop()
stopChan := make(chan bool, 2) stopChan := make(chan bool)
defer close(stopChan) defer close(stopChan)
go func() { go func() {
@@ -55,7 +57,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
} }
} }
stopChan <- true common.SafeSendBool(stopChan, true)
}() }()
select { select {
@@ -82,6 +84,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId = streamResponse.Id responseId = streamResponse.Id
createAt = streamResponse.Created createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint() systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) { if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage usage = streamResponse.Usage
hasStreamUsage = true hasStreamUsage = true
@@ -105,6 +108,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseId = streamResponse.Id responseId = streamResponse.Id
createAt = streamResponse.Created createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint() systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) { if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage usage = streamResponse.Usage
hasStreamUsage = true hasStreamUsage = true
@@ -153,7 +157,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
if info.ShouldIncludeUsage && !hasStreamUsage { if info.ShouldIncludeUsage && !hasStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage) response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint) response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response) service.ObjectData(c, response)
} }
@@ -162,7 +166,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount common.LogError(c, "close_response_body_failed: "+err.Error())
} }
return nil, usage, responseTextBuilder.String(), toolCount return nil, usage, responseTextBuilder.String(), toolCount
} }