diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index 04260723..a0fa7f98 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -39,6 +39,15 @@ func stopReasonClaude2OpenAI(reason *string) string { } } +// isModelSupportThinking is used to check if the model supports extended thinking +func isModelSupportThinking(model string) bool { + if strings.Contains(model, "claude-3-7-sonnet") { + return true + } + + return false +} + func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) { claudeTools := make([]Tool, 0, len(textRequest.Tools)) @@ -67,14 +76,16 @@ func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Re Thinking: textRequest.Thinking, } - if c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil { + if isModelSupportThinking(textRequest.Model) && + c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil { claudeRequest.Thinking = &model.Thinking{ Type: "enabled", BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))), } } - if claudeRequest.Thinking != nil { + if isModelSupportThinking(textRequest.Model) && + claudeRequest.Thinking != nil { if claudeRequest.MaxTokens <= 1024 { return nil, fmt.Errorf("max_tokens must be greater than 1024 when using extended thinking") } diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 69251c3d..762f5ba6 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go index e5fcd89f..aff3e0cf 100644 --- a/relay/adaptor/aws/llama3/main.go +++ b/relay/adaptor/aws/llama3/main.go @@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/utils/consts.go b/relay/adaptor/aws/utils/consts.go new file mode 100644 index 00000000..c91f342e --- /dev/null +++ b/relay/adaptor/aws/utils/consts.go @@ -0,0 +1,75 @@ +package utils + +import ( + "context" + "slices" + "strings" + + "github.com/songquanpeng/one-api/common/logger" +) + +// CrossRegionInferences is a list of model IDs that support cross-region inference. +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +// +// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)}) +var CrossRegionInferences = []string{ + "us.amazon.nova-lite-v1:0", + "us.amazon.nova-micro-v1:0", + "us.amazon.nova-pro-v1:0", + "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "us.anthropic.claude-3-haiku-20240307-v1:0", + "us.anthropic.claude-3-opus-20240229-v1:0", + "us.anthropic.claude-3-sonnet-20240229-v1:0", + "us.meta.llama3-1-405b-instruct-v1:0", + "us.meta.llama3-1-70b-instruct-v1:0", + "us.meta.llama3-1-8b-instruct-v1:0", + "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-1b-instruct-v1:0", + "us.meta.llama3-2-3b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0", + "us.meta.llama3-3-70b-instruct-v1:0", + "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us-gov.anthropic.claude-3-haiku-20240307-v1:0", + "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-micro-v1:0", + "eu.amazon.nova-pro-v1:0", + "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", + "eu.anthropic.claude-3-haiku-20240307-v1:0", + "eu.anthropic.claude-3-sonnet-20240229-v1:0", + "eu.meta.llama3-2-1b-instruct-v1:0", + "eu.meta.llama3-2-3b-instruct-v1:0", + "apac.amazon.nova-lite-v1:0", + "apac.amazon.nova-micro-v1:0", + "apac.amazon.nova-pro-v1:0", + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "apac.anthropic.claude-3-5-sonnet-20241022-v2:0", + "apac.anthropic.claude-3-haiku-20240307-v1:0", + "apac.anthropic.claude-3-sonnet-20240229-v1:0", +} + +// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID. +func ConvertModelID2CrossRegionProfile(model, region string) string { + var regionPrefix string + switch prefix := strings.Split(region, "-")[0]; prefix { + case "us", "eu": + regionPrefix = prefix + case "ap": + regionPrefix = "apac" + default: + // not supported, return original model + return model + } + + newModelID := regionPrefix + "." + model + if slices.Contains(CrossRegionInferences, newModelID) { + logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID) + return newModelID + } + + // not found, return original model + return model +}