one-api/relay/adaptor/aws/main.go
2024-04-18 02:52:26 +00:00

234 lines
6.6 KiB
Go

// Package aws provides the AWS adaptor for the relay service.
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
ks := strings.Split(channel.Key, "\n")
if len(ks) != 2 {
return nil, errors.New("invalid key")
}
ak, sk := ks[0], ks[1]
client := bedrockruntime.New(bedrockruntime.Options{
Region: *channel.BaseURL,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%+v", err),
},
}
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
func awsModelID(requestModel string) (string, error) {
switch requestModel {
case "claude-instant-1.2":
return "anthropic.claude-instant-v1", nil
case "claude-2.0":
return "anthropic.claude-v2", nil
case "claude-2.1":
return "anthropic.claude-v2:1", nil
case "claude-3-sonnet-20240229":
return "anthropic.claude-3-sonnet-20240229-v1:0", nil
case "claude-3-opus-20240229":
return "anthropic.claude-3-opus-20240229-v1:0", nil
case "claude-3-haiku-20240307":
return "anthropic.claude-3-haiku-20240307-v1:0", nil
default:
return "", errors.Errorf("unknown model: %s", requestModel)
}
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
openaiResp.Model = modelName
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(channel.Models)
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
var usage relaymodel.Usage
var id string
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
claudeResp := new(anthropic.StreamResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}