feat: support aws bedrockruntime claude3

closes #622, closes #749, closes #1300
This commit is contained in:
Laisky.Cai
2024-04-18 02:50:12 +00:00
parent b638a2fcbd
commit 4e1bfe4879
19 changed files with 458 additions and 21 deletions

View File

@@ -0,0 +1,78 @@
package aws
import (
"io"
"net/http"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
c.Set(common.CtxKeyRequestModel, request.Model)
c.Set(common.CtxKeyRawRequest, request)
c.Set(common.CtxKeyConvertedRequest, claudeReq)
return claudeReq, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var ModelList = []string{
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aws"
}

233
relay/adaptor/aws/main.go Normal file
View File

@@ -0,0 +1,233 @@
// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
ks := strings.Split(channel.Key, "\n")
if len(ks) != 2 {
return nil, errors.New("invalid key")
}
ak, sk := ks[0], ks[1]
client := bedrockruntime.New(bedrockruntime.Options{
Region: *channel.BaseURL,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%+v", err),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
func awsModelID(requestModel string) (string, error) {
switch requestModel {
case "claude-instant-1.2":
return "anthropic.claude-instant-v1", nil
case "claude-2.0":
return "anthropic.claude-v2", nil
case "claude-2.1":
return "anthropic.claude-v2:1", nil
case "claude-3-sonnet-20240229":
return "anthropic.claude-3-sonnet-20240229-v1:0", nil
case "claude-3-opus-20240229":
return "anthropic.claude-3-opus-20240229-v1:0", nil
case "claude-3-haiku-20240307":
return "anthropic.claude-3-haiku-20240307-v1:0", nil
default:
return "", errors.Errorf("unknown model: %s", requestModel)
}
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
var usage relaymodel.Usage
var id string
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
claudeResp := new(anthropic.StreamResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}

View File

@@ -0,0 +1,17 @@
package aws
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
// Request is the request to AWS Claude
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
}