mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 05:33:42 +08:00
🐛 fix: stream mode delay issue (#53)
This commit is contained in:
@@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response,
|
||||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return &streamReader[T]{
|
||||
stream := &streamReader[T]{
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
|
||||
DataChan: make(chan T),
|
||||
ErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// 设置请求体
|
||||
|
||||
@@ -3,16 +3,12 @@ package requester
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 流处理函数,判断依据如下:
|
||||
// 1.如果有错误信息,则直接返回错误信息
|
||||
// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response
|
||||
// 3.如果rawLine=nil 或者 response长度为0,则直接跳过
|
||||
// 4.如果以上条件都不满足,则返回response
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
|
||||
var StreamClosed = []byte("stream_closed")
|
||||
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error)
|
||||
|
||||
type streamable interface {
|
||||
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
||||
@@ -20,57 +16,48 @@ type streamable interface {
|
||||
}
|
||||
|
||||
type StreamReaderInterface[T streamable] interface {
|
||||
Recv() (*[]T, error)
|
||||
Recv() (<-chan T, <-chan error)
|
||||
Close()
|
||||
}
|
||||
|
||||
type streamReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
|
||||
DataChan chan T
|
||||
ErrChan chan error
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
|
||||
go stream.processLines()
|
||||
|
||||
return stream.DataChan, stream.ErrChan
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() (*[]T, error) {
|
||||
func (stream *streamReader[T]) processLines() {
|
||||
for {
|
||||
rawLine, readErr := stream.reader.ReadBytes('\n')
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
stream.ErrChan <- readErr
|
||||
return
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
|
||||
var response []T
|
||||
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if noSpaceLine == nil || len(response) == 0 {
|
||||
if len(noSpaceLine) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan)
|
||||
|
||||
if noSpaceLine == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(noSpaceLine, StreamClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,55 +1,41 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"io"
|
||||
"bytes"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type wsReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *websocket.Conn
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
|
||||
DataChan chan T
|
||||
ErrChan chan error
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) {
|
||||
go stream.processLines()
|
||||
return stream.DataChan, stream.ErrChan
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) processLines() (*[]T, error) {
|
||||
func (stream *wsReader[T]) processLines() {
|
||||
for {
|
||||
_, msg, err := stream.reader.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
stream.ErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
var response []T
|
||||
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
|
||||
stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if msg == nil || len(response) == 0 {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
if bytes.Equal(msg, StreamClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref
|
||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return &wsReader[T]{
|
||||
stream := &wsReader[T]{
|
||||
reader: conn,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
|
||||
DataChan: make(chan T),
|
||||
ErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
|
||||
Reference in New Issue
Block a user