one-api/providers/bedrock/stream_reader.go
Buer b81808e839
feat: support amazon bedrock anthropic (#114)
* 🚧 WIP: bedrock

*  feat: support amazon bedrock anthropic
2024-03-18 16:00:35 +08:00

129 lines
3.1 KiB
Go

package bedrock
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi"
"github.com/aws/smithy-go"
)
// streamReader reads a streamed Bedrock HTTP response frame by frame and
// hands each decoded chunk to handlerPrefix, which publishes results of
// type T (or errors) on the channels below.
type streamReader[T any] struct {
// reader is a buffered reader over the HTTP response body.
reader *bufio.Reader
// response is the originating HTTP response; its body is closed by Close.
response *http.Response
// handlerPrefix is invoked once per decoded chunk with a pointer to the
// chunk bytes and both channels.
handlerPrefix requester.HandlerPrefix[T]
// DataChan is the channel returned by Recv for converted stream items.
DataChan chan T
// ErrChan is the channel returned by Recv for read/decode errors.
ErrChan chan error
}
// Recv starts the background goroutine that reads and decodes the stream,
// and returns receive-only views of the data and error channels.
//
// Every call starts a new reader goroutine over the same underlying reader.
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
	var (
		dataCh <-chan T     = stream.DataChan
		errCh  <-chan error = stream.ErrChan
	)

	go stream.processLines()

	return dataCh, errCh
}
// processLines is the reader loop started by Recv. It decodes event-stream
// frames from the response body one at a time, extracts the chunk bytes via
// deserializeEventMessage, and passes them to handlerPrefix.
//
// The loop ends when:
//   - the decoder returns an error (forwarded unchanged on ErrChan),
//   - a frame cannot be deserialized (forwarded on ErrChan wrapped as
//     "decode_response_failed" with HTTP 500), or
//   - after the handler call the line equals requester.StreamClosed.
//
//nolint:gocognit
func (stream *streamReader[T]) processLines() {
	decoder := eventstream.NewDecoder()

	// Scratch buffer lent to the decoder for each frame's payload. It needs
	// real capacity to be reusable; it is safe to reuse across iterations
	// because deserializeEventMessage produces freshly allocated bytes
	// (JSON unmarshal + base64 decode) before the next Decode call.
	payloadBuf := make([]byte, 0, 10*1024)

	for {
		payloadBuf = payloadBuf[:0]

		message, readErr := decoder.Decode(stream.reader, payloadBuf)
		if readErr != nil {
			stream.ErrChan <- readErr
			return
		}

		line, err := stream.deserializeEventMessage(&message)
		if err != nil {
			stream.ErrChan <- common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
			return
		}

		// The handler receives a pointer to the line, so the checks below
		// observe whatever value it leaves there.
		stream.handlerPrefix(&line, stream.DataChan, stream.ErrChan)

		if line == nil {
			continue
		}
		if bytes.Equal(line, requester.StreamClosed) {
			return
		}
	}
}
// Close closes the body of the underlying HTTP response. Any error reported
// by the body's Close method is discarded.
func (stream *streamReader[T]) Close() {
	body := stream.response.Body
	_ = body.Close()
}
// deserializeEventMessage converts one decoded event-stream frame into the
// raw chunk bytes it carries.
//
// Behavior by message-type header:
//   - event: the payload is parsed as JSON into BedrockResponseStream and
//     its Bytes field is base64-decoded; the decoded bytes are returned.
//   - exception: an error naming the exception type is returned.
//   - error: a *smithy.GenericAPIError is returned, built from the error
//     code/message headers ("UnknownError" when a header is absent).
//   - anything else: a generic "unknown error" is returned.
//
// An error is also returned when the message-type header itself is missing,
// or when JSON/base64 decoding of an event payload fails.
func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) ([]byte, error) {
	messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader)
	if messageType == nil {
		return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader)
	}

	switch messageType.String() {
	case eventstreamapi.EventMessageType:
		var v BedrockResponseStream
		if err := json.Unmarshal(msg.Payload, &v); err != nil {
			return nil, err
		}
		return base64.StdEncoding.DecodeString(v.Bytes)

	case eventstreamapi.ExceptionMessageType:
		// Headers.Get returns nil for an absent header; guard it like the
		// other headers so a malformed frame cannot panic this goroutine.
		exceptionType := "UnknownException"
		if header := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader); header != nil {
			exceptionType = header.String()
		}
		return nil, errors.New("Exception message :" + exceptionType)

	case eventstreamapi.ErrorMessageType:
		errorCode := "UnknownError"
		errorMessage := errorCode
		if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil {
			errorCode = header.String()
		}
		if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil {
			errorMessage = header.String()
		}
		return nil, &smithy.GenericAPIError{
			Code:    errorCode,
			Message: errorMessage,
		}

	default:
		return nil, errors.New("bedrock stream unknown error")
	}
}
// RequestStream wraps an HTTP response in a streamReader that will feed
// decoded chunks through handlerPrefix.
//
// A response whose Content-Type contains "application/json" is treated as
// an error reply rather than a stream: it is passed to
// requester.HandleErrorResp and no reader is returned.
func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
	contentType := resp.Header.Get("Content-Type")
	if strings.Contains(contentType, "application/json") {
		// A JSON content type means the upstream returned an error.
		return nil, requester.HandleErrorResp(resp, requestErrorHandle)
	}

	return &streamReader[T]{
		reader:        bufio.NewReader(resp.Body),
		response:      resp,
		handlerPrefix: handlerPrefix,
		DataChan:      make(chan T),
		ErrChan:       make(chan error),
	}, nil
}