Merge commit '2369025842b828ac38f4427fd1ebab8d03b1fe7f'

This commit is contained in:
Laisky.Cai
2024-04-20 01:07:29 +00:00
139 changed files with 2642 additions and 2625 deletions

View File

@@ -7,8 +7,14 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/relay/adaptor/anthropic"
relaymodel "github.com/Laisky/one-api/relay/model"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
@@ -16,23 +22,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
ks := strings.Split(channel.Key, "\n")
if len(ks) != 2 {
return nil, errors.New("invalid key")
}
ak, sk := ks[0], ks[1]
func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) {
ak := c.GetString(config.KeyAK)
sk := c.GetString(config.KeySK)
region := c.GetString(config.KeyRegion)
client := bedrockruntime.New(bedrockruntime.Options{
Region: *channel.BaseURL,
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
@@ -43,7 +40,7 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%+v", err),
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
@@ -67,19 +64,12 @@ func awsModelID(requestModel string) (string, error) {
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel))
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
@@ -90,11 +80,11 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
@@ -133,20 +123,12 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
var channel *model.Channel
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
return wrapErr(errors.New("channel not found")), nil
} else {
channel = channeli.(*model.Channel)
}
awsCli, err := newAwsClient(channel)
awsCli, err := newAwsClient(c)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString(common.CtxKeyRequestModel))
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
@@ -157,11 +139,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
ContentType: aws.String("application/json"),
}
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReqi.(*anthropic.Request)
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
@@ -211,7 +193,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithSt
return true
}
response.Id = id
response.Model = c.GetString(common.CtxKeyOriginModel)
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {