diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index fe4e2ef0..7acee1f1 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - return ConvertRequest(*request), nil + return ConvertRequest(c, *request) } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index b906e68d..65aaefa0 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -6,16 +6,17 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" ) @@ -38,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string { } } -func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { +// isModelSupportThinking is used to check if the model supports extended thinking +func isModelSupportThinking(model string) bool { + if strings.Contains(model, "claude-3-7-sonnet") { + return true + } + + return false +} + +func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) { claudeTools := make([]Tool, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { @@ -65,6 +75,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { Tools: claudeTools, Thinking: textRequest.Thinking, } + + if isModelSupportThinking(textRequest.Model) && + c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil { + claudeRequest.Thinking = &model.Thinking{ + Type: "enabled", + BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))), + } + } + + if isModelSupportThinking(textRequest.Model) && + claudeRequest.Thinking != nil { + if claudeRequest.MaxTokens <= 1024 { + return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking") + } + + // top_p must be nil when using extended thinking + claudeRequest.TopP = nil + } + if len(claudeTools) > 0 { claudeToolChoice := struct { Type string `json:"type"` @@ -145,7 +174,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { claudeMessage.Content = contents claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) } - return &claudeRequest + return &claudeRequest, nil } // https://docs.anthropic.com/claude/reference/messages-streaming diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go index eb3c9fb8..2f6a4cc5 100644 --- a/relay/adaptor/aws/claude/adapter.go +++ b/relay/adaptor/aws/claude/adapter.go @@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) + claudeReq, err := anthropic.ConvertRequest(c, *request) + if err != nil { + return nil, errors.Wrap(err, "convert request") + } + c.Set(ctxkey.RequestModel, request.Model) c.Set(ctxkey.ConvertedRequest, claudeReq) return claudeReq, nil diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 378acdda..da000b58 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go index e5fcd89f..aff3e0cf 100644 --- a/relay/adaptor/aws/llama3/main.go +++ b/relay/adaptor/aws/llama3/main.go @@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/utils/consts.go b/relay/adaptor/aws/utils/consts.go new file mode 100644 index 00000000..c91f342e --- /dev/null +++ b/relay/adaptor/aws/utils/consts.go @@ -0,0 +1,75 @@ +package utils + +import ( + "context" + "slices" + "strings" + + "github.com/songquanpeng/one-api/common/logger" +) + +// CrossRegionInferences is a list of model IDs that support cross-region inference. +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +// +// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)}) +var CrossRegionInferences = []string{ + "us.amazon.nova-lite-v1:0", + "us.amazon.nova-micro-v1:0", + "us.amazon.nova-pro-v1:0", + "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "us.anthropic.claude-3-haiku-20240307-v1:0", + "us.anthropic.claude-3-opus-20240229-v1:0", + "us.anthropic.claude-3-sonnet-20240229-v1:0", + "us.meta.llama3-1-405b-instruct-v1:0", + "us.meta.llama3-1-70b-instruct-v1:0", + "us.meta.llama3-1-8b-instruct-v1:0", + "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-1b-instruct-v1:0", + "us.meta.llama3-2-3b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0", + "us.meta.llama3-3-70b-instruct-v1:0", + "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us-gov.anthropic.claude-3-haiku-20240307-v1:0", + "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-micro-v1:0", + "eu.amazon.nova-pro-v1:0", + "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", + "eu.anthropic.claude-3-haiku-20240307-v1:0", + "eu.anthropic.claude-3-sonnet-20240229-v1:0", + "eu.meta.llama3-2-1b-instruct-v1:0", + "eu.meta.llama3-2-3b-instruct-v1:0", + "apac.amazon.nova-lite-v1:0", + "apac.amazon.nova-micro-v1:0", + "apac.amazon.nova-pro-v1:0", + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "apac.anthropic.claude-3-5-sonnet-20241022-v2:0", + "apac.anthropic.claude-3-haiku-20240307-v1:0", + "apac.anthropic.claude-3-sonnet-20240229-v1:0", +} + +// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID. +func ConvertModelID2CrossRegionProfile(model, region string) string { + var regionPrefix string + switch prefix := strings.Split(region, "-")[0]; prefix { + case "us", "eu": + regionPrefix = prefix + case "ap": + regionPrefix = "apac" + default: + // not supported, return original model + return model + } + + newModelID := regionPrefix + "." + model + if slices.Contains(CrossRegionInferences, newModelID) { + logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID) + return newModelID + } + + // not found, return original model + return model +} diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index 554ada45..f591e447 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -32,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) + claudeReq, err := anthropic.ConvertRequest(c, *request) + if err != nil { + return nil, errors.Wrap(err, "convert request") + } + req := Request{ AnthropicVersion: anthropicVersion, // Model: claudeReq.Model,