feat: add support for AWS Bedrock cross-region inference

closes #2024, closes #2145
This commit is contained in:
Laisky.Cai 2025-03-10 06:37:42 +00:00
parent 92774271b3
commit 5296a588b1
4 changed files with 94 additions and 6 deletions

View File

@ -39,6 +39,15 @@ func stopReasonClaude2OpenAI(reason *string) string {
} }
} }
// isModelSupportThinking reports whether the given model name refers to a
// model that supports extended thinking. At present only names containing
// "claude-3-7-sonnet" qualify.
func isModelSupportThinking(model string) bool {
	return strings.Contains(model, "claude-3-7-sonnet")
}
func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) { func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools)) claudeTools := make([]Tool, 0, len(textRequest.Tools))
@ -67,14 +76,16 @@ func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Re
Thinking: textRequest.Thinking, Thinking: textRequest.Thinking,
} }
if c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil { if isModelSupportThinking(textRequest.Model) &&
c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
claudeRequest.Thinking = &model.Thinking{ claudeRequest.Thinking = &model.Thinking{
Type: "enabled", Type: "enabled",
BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))), BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
} }
} }
if claudeRequest.Thinking != nil { if isModelSupportThinking(textRequest.Model) &&
claudeRequest.Thinking != nil {
if claudeRequest.MaxTokens <= 1024 { if claudeRequest.MaxTokens <= 1024 {
return nil, fmt.Errorf("max_tokens must be greater than 1024 when using extended thinking") return nil, fmt.Errorf("max_tokens must be greater than 1024 when using extended thinking")
} }

View File

@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) {
} }
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{ awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId), ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"), Accept: aws.String("application/json"),
ContentType: aws.String("application/json"), ContentType: aws.String("application/json"),
} }

View File

@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
} }
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil { if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
} }
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{ awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId), ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"), Accept: aws.String("application/json"),
ContentType: aws.String("application/json"), ContentType: aws.String("application/json"),
} }

View File

@ -0,0 +1,75 @@
package utils
import (
"context"
"slices"
"strings"
"github.com/songquanpeng/one-api/common/logger"
)
// CrossRegionInferences is a list of model IDs that support cross-region inference.
//
// Every entry has the form "<geo-prefix>.<model ID>", where the prefixes
// present in this list are "us", "us-gov", "eu" and "apac".
// ConvertModelID2CrossRegionProfile builds a candidate ID in that form and
// only uses it when it appears in this list, so a model/geography pair that
// is missing here is never rewritten.
//
// Source of the list:
// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
//
// Presumably the following browser-console snippet was run on that page to
// dump the IDs when the list was last refreshed (TODO confirm):
//
// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
var CrossRegionInferences = []string{
"us.amazon.nova-lite-v1:0",
"us.amazon.nova-micro-v1:0",
"us.amazon.nova-pro-v1:0",
"us.anthropic.claude-3-5-haiku-20241022-v1:0",
"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"us.anthropic.claude-3-haiku-20240307-v1:0",
"us.anthropic.claude-3-opus-20240229-v1:0",
"us.anthropic.claude-3-sonnet-20240229-v1:0",
"us.meta.llama3-1-405b-instruct-v1:0",
"us.meta.llama3-1-70b-instruct-v1:0",
"us.meta.llama3-1-8b-instruct-v1:0",
"us.meta.llama3-2-11b-instruct-v1:0",
"us.meta.llama3-2-1b-instruct-v1:0",
"us.meta.llama3-2-3b-instruct-v1:0",
"us.meta.llama3-2-90b-instruct-v1:0",
"us.meta.llama3-3-70b-instruct-v1:0",
"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
"eu.amazon.nova-lite-v1:0",
"eu.amazon.nova-micro-v1:0",
"eu.amazon.nova-pro-v1:0",
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"eu.anthropic.claude-3-haiku-20240307-v1:0",
"eu.anthropic.claude-3-sonnet-20240229-v1:0",
"eu.meta.llama3-2-1b-instruct-v1:0",
"eu.meta.llama3-2-3b-instruct-v1:0",
"apac.amazon.nova-lite-v1:0",
"apac.amazon.nova-micro-v1:0",
"apac.amazon.nova-pro-v1:0",
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
"apac.anthropic.claude-3-haiku-20240307-v1:0",
"apac.anthropic.claude-3-sonnet-20240229-v1:0",
}
// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region
// profile ID.
//
// The geography prefix is derived from the AWS region name:
//
//   - "us-gov-*" regions map to "us-gov"
//   - other "us-*" regions map to "us"
//   - "eu-*" regions map to "eu"
//   - "ap-*" regions map to "apac"
//
// The candidate "<prefix>.<model>" is returned only when it is listed in
// CrossRegionInferences. In every other case (unsupported geography, or no
// profile listed for this model in this geography) the original model ID is
// returned unchanged. A model ID that already carries a geography prefix
// also falls through unchanged, because the doubly-prefixed candidate is not
// in the list.
func ConvertModelID2CrossRegionProfile(model, region string) string {
	// Split once: geo is the first dash-separated segment, rest is whatever
	// follows it (empty when region contains no dash).
	geo, rest, _ := strings.Cut(region, "-")

	var regionPrefix string
	switch geo {
	case "us":
		regionPrefix = "us"
		// GovCloud regions ("us-gov-west-1", "us-gov-east-1") have their own
		// "us-gov." profiles. Without this check they would be rewritten to
		// the commercial "us." profiles and the "us-gov." entries in
		// CrossRegionInferences could never match.
		if strings.HasPrefix(rest, "gov-") {
			regionPrefix = "us-gov"
		}
	case "eu":
		regionPrefix = "eu"
	case "ap":
		regionPrefix = "apac"
	default:
		// not supported, return original model
		return model
	}

	newModelID := regionPrefix + "." + model
	if slices.Contains(CrossRegionInferences, newModelID) {
		logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
		return newModelID
	}

	// not found, return original model
	return model
}