feat: add support for AWS cross-region inference

closes #2024, closes #2145
This commit is contained in:
Laisky.Cai
2025-03-10 06:37:42 +00:00
parent c61d6440f9
commit de10e102bd
7 changed files with 125 additions and 11 deletions

View File

@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
claudeReq := anthropic.ConvertRequest(*request)
claudeReq, err := anthropic.ConvertRequest(c, *request)
if err != nil {
return nil, errors.Wrap(err, "convert request")
}
c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil

View File

@@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) {
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}

View File

@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
}
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}

View File

@@ -0,0 +1,75 @@
package utils
import (
"context"
"slices"
"strings"
"github.com/songquanpeng/one-api/common/logger"
)
// CrossRegionInferences lists the cross-region inference profile IDs known to
// this package.
//
// Each entry is a geography prefix ("us", "us-gov", "eu" or "apac"), a dot, and
// a base model ID. ConvertModelID2CrossRegionProfile only rewrites a model ID
// when the prefixed candidate it builds appears in this list.
//
// Source of the list:
// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
//
// NOTE(review): the snippet below is presumably meant to be run in the browser
// console on the page above to dump the profile IDs when refreshing this list;
// confirm it still matches the page markup before relying on it.
//
//	document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
var CrossRegionInferences = []string{
"us.amazon.nova-lite-v1:0",
"us.amazon.nova-micro-v1:0",
"us.amazon.nova-pro-v1:0",
"us.anthropic.claude-3-5-haiku-20241022-v1:0",
"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"us.anthropic.claude-3-haiku-20240307-v1:0",
"us.anthropic.claude-3-opus-20240229-v1:0",
"us.anthropic.claude-3-sonnet-20240229-v1:0",
"us.meta.llama3-1-405b-instruct-v1:0",
"us.meta.llama3-1-70b-instruct-v1:0",
"us.meta.llama3-1-8b-instruct-v1:0",
"us.meta.llama3-2-11b-instruct-v1:0",
"us.meta.llama3-2-1b-instruct-v1:0",
"us.meta.llama3-2-3b-instruct-v1:0",
"us.meta.llama3-2-90b-instruct-v1:0",
"us.meta.llama3-3-70b-instruct-v1:0",
"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
"eu.amazon.nova-lite-v1:0",
"eu.amazon.nova-micro-v1:0",
"eu.amazon.nova-pro-v1:0",
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"eu.anthropic.claude-3-haiku-20240307-v1:0",
"eu.anthropic.claude-3-sonnet-20240229-v1:0",
"eu.meta.llama3-2-1b-instruct-v1:0",
"eu.meta.llama3-2-3b-instruct-v1:0",
"apac.amazon.nova-lite-v1:0",
"apac.amazon.nova-micro-v1:0",
"apac.amazon.nova-pro-v1:0",
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
"apac.anthropic.claude-3-haiku-20240307-v1:0",
"apac.anthropic.claude-3-sonnet-20240229-v1:0",
}
// ConvertModelID2CrossRegionProfile converts a model ID to the cross-region
// inference profile ID for the given AWS region.
//
// The candidate profile ID is "<geo>.<model>", where geo is derived from region:
//
//   - "us-gov" for us-gov-* regions
//   - "us" for every other us-* region
//   - "eu" for eu-* regions
//   - "apac" for ap-* regions
//
// The candidate is returned only when it is listed in CrossRegionInferences.
// In every other case (unrecognized geography, empty region, or no profile for
// this model in that geography) model is returned unchanged.
func ConvertModelID2CrossRegionProfile(model, region string) string {
	// Text before the first "-" ("us", "eu", "ap", ...); the whole string when
	// region contains no dash.
	geo, _, _ := strings.Cut(region, "-")

	var regionPrefix string
	switch {
	case strings.HasPrefix(region, "us-gov-"):
		// Must be tested before the generic "us" case: the first segment of a
		// GovCloud region is also "us", but its profiles use the "us-gov." prefix.
		regionPrefix = "us-gov"
	case geo == "us", geo == "eu":
		regionPrefix = geo
	case geo == "ap":
		regionPrefix = "apac"
	default:
		// not supported, return original model
		return model
	}

	newModelID := regionPrefix + "." + model
	if slices.Contains(CrossRegionInferences, newModelID) {
		logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
		return newModelID
	}

	// not found, return original model
	return model
}