mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: add support for aws's cross region inferences
closes #2024, closes #2145
This commit is contained in:
		@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
				
			|||||||
	if request == nil {
 | 
						if request == nil {
 | 
				
			||||||
		return nil, errors.New("request is nil")
 | 
							return nil, errors.New("request is nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return ConvertRequest(*request), nil
 | 
						return ConvertRequest(c, *request)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
					func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,16 +6,17 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
 | 
						"math"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/songquanpeng/one-api/common/render"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/common"
 | 
						"github.com/songquanpeng/one-api/common"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
						"github.com/songquanpeng/one-api/common/helper"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/common/image"
 | 
						"github.com/songquanpeng/one-api/common/image"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
						"github.com/songquanpeng/one-api/common/logger"
 | 
				
			||||||
 | 
						"github.com/songquanpeng/one-api/common/render"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
						"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
				
			||||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
						"github.com/songquanpeng/one-api/relay/model"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -38,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
					// isModelSupportThinking is used to check if the model supports extended thinking
 | 
				
			||||||
 | 
					func isModelSupportThinking(model string) bool {
 | 
				
			||||||
 | 
						if strings.Contains(model, "claude-3-7-sonnet") {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
 | 
				
			||||||
	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 | 
						claudeTools := make([]Tool, 0, len(textRequest.Tools))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tool := range textRequest.Tools {
 | 
						for _, tool := range textRequest.Tools {
 | 
				
			||||||
@@ -65,6 +75,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
				
			|||||||
		Tools:       claudeTools,
 | 
							Tools:       claudeTools,
 | 
				
			||||||
		Thinking:    textRequest.Thinking,
 | 
							Thinking:    textRequest.Thinking,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if isModelSupportThinking(textRequest.Model) &&
 | 
				
			||||||
 | 
							c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
 | 
				
			||||||
 | 
							claudeRequest.Thinking = &model.Thinking{
 | 
				
			||||||
 | 
								Type:         "enabled",
 | 
				
			||||||
 | 
								BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if isModelSupportThinking(textRequest.Model) &&
 | 
				
			||||||
 | 
							claudeRequest.Thinking != nil {
 | 
				
			||||||
 | 
							if claudeRequest.MaxTokens <= 1024 {
 | 
				
			||||||
 | 
								return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// top_p must be nil when using extended thinking
 | 
				
			||||||
 | 
							claudeRequest.TopP = nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(claudeTools) > 0 {
 | 
						if len(claudeTools) > 0 {
 | 
				
			||||||
		claudeToolChoice := struct {
 | 
							claudeToolChoice := struct {
 | 
				
			||||||
			Type string `json:"type"`
 | 
								Type string `json:"type"`
 | 
				
			||||||
@@ -145,7 +174,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
				
			|||||||
		claudeMessage.Content = contents
 | 
							claudeMessage.Content = contents
 | 
				
			||||||
		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 | 
							claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &claudeRequest
 | 
						return &claudeRequest, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://docs.anthropic.com/claude/reference/messages-streaming
 | 
					// https://docs.anthropic.com/claude/reference/messages-streaming
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
				
			|||||||
		return nil, errors.New("request is nil")
 | 
							return nil, errors.New("request is nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	claudeReq := anthropic.ConvertRequest(*request)
 | 
						claudeReq, err := anthropic.ConvertRequest(c, *request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "convert request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.Set(ctxkey.RequestModel, request.Model)
 | 
						c.Set(ctxkey.RequestModel, request.Model)
 | 
				
			||||||
	c.Set(ctxkey.ConvertedRequest, claudeReq)
 | 
						c.Set(ctxkey.ConvertedRequest, claudeReq)
 | 
				
			||||||
	return claudeReq, nil
 | 
						return claudeReq, nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 | 
					func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 | 
				
			||||||
	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
 | 
						awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 | 
							return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 | 
				
			||||||
	awsReq := &bedrockruntime.InvokeModelInput{
 | 
						awsReq := &bedrockruntime.InvokeModelInput{
 | 
				
			||||||
		ModelId:     aws.String(awsModelId),
 | 
							ModelId:     aws.String(awsModelID),
 | 
				
			||||||
		Accept:      aws.String("application/json"),
 | 
							Accept:      aws.String("application/json"),
 | 
				
			||||||
		ContentType: aws.String("application/json"),
 | 
							ContentType: aws.String("application/json"),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 | 
					func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 | 
				
			||||||
	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
 | 
						awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 | 
							return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 | 
				
			||||||
	awsReq := &bedrockruntime.InvokeModelInput{
 | 
						awsReq := &bedrockruntime.InvokeModelInput{
 | 
				
			||||||
		ModelId:     aws.String(awsModelId),
 | 
							ModelId:     aws.String(awsModelID),
 | 
				
			||||||
		Accept:      aws.String("application/json"),
 | 
							Accept:      aws.String("application/json"),
 | 
				
			||||||
		ContentType: aws.String("application/json"),
 | 
							ContentType: aws.String("application/json"),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										75
									
								
								relay/adaptor/aws/utils/consts.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								relay/adaptor/aws/utils/consts.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"slices"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/songquanpeng/one-api/common/logger"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CrossRegionInferences is a list of model IDs that support cross-region inference.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
 | 
				
			||||||
 | 
					var CrossRegionInferences = []string{
 | 
				
			||||||
 | 
						"us.amazon.nova-lite-v1:0",
 | 
				
			||||||
 | 
						"us.amazon.nova-micro-v1:0",
 | 
				
			||||||
 | 
						"us.amazon.nova-pro-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-5-haiku-20241022-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-haiku-20240307-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-opus-20240229-v1:0",
 | 
				
			||||||
 | 
						"us.anthropic.claude-3-sonnet-20240229-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-1-405b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-1-70b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-1-8b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-2-11b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-2-1b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-2-3b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-2-90b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us.meta.llama3-3-70b-instruct-v1:0",
 | 
				
			||||||
 | 
						"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
				
			||||||
 | 
						"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
 | 
				
			||||||
 | 
						"eu.amazon.nova-lite-v1:0",
 | 
				
			||||||
 | 
						"eu.amazon.nova-micro-v1:0",
 | 
				
			||||||
 | 
						"eu.amazon.nova-pro-v1:0",
 | 
				
			||||||
 | 
						"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
				
			||||||
 | 
						"eu.anthropic.claude-3-haiku-20240307-v1:0",
 | 
				
			||||||
 | 
						"eu.anthropic.claude-3-sonnet-20240229-v1:0",
 | 
				
			||||||
 | 
						"eu.meta.llama3-2-1b-instruct-v1:0",
 | 
				
			||||||
 | 
						"eu.meta.llama3-2-3b-instruct-v1:0",
 | 
				
			||||||
 | 
						"apac.amazon.nova-lite-v1:0",
 | 
				
			||||||
 | 
						"apac.amazon.nova-micro-v1:0",
 | 
				
			||||||
 | 
						"apac.amazon.nova-pro-v1:0",
 | 
				
			||||||
 | 
						"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
				
			||||||
 | 
						"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
 | 
				
			||||||
 | 
						"apac.anthropic.claude-3-haiku-20240307-v1:0",
 | 
				
			||||||
 | 
						"apac.anthropic.claude-3-sonnet-20240229-v1:0",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID.
 | 
				
			||||||
 | 
					func ConvertModelID2CrossRegionProfile(model, region string) string {
 | 
				
			||||||
 | 
						var regionPrefix string
 | 
				
			||||||
 | 
						switch prefix := strings.Split(region, "-")[0]; prefix {
 | 
				
			||||||
 | 
						case "us", "eu":
 | 
				
			||||||
 | 
							regionPrefix = prefix
 | 
				
			||||||
 | 
						case "ap":
 | 
				
			||||||
 | 
							regionPrefix = "apac"
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							// not supported, return original model
 | 
				
			||||||
 | 
							return model
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						newModelID := regionPrefix + "." + model
 | 
				
			||||||
 | 
						if slices.Contains(CrossRegionInferences, newModelID) {
 | 
				
			||||||
 | 
							logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
 | 
				
			||||||
 | 
							return newModelID
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// not found, return original model
 | 
				
			||||||
 | 
						return model
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -32,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
				
			|||||||
		return nil, errors.New("request is nil")
 | 
							return nil, errors.New("request is nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	claudeReq := anthropic.ConvertRequest(*request)
 | 
						claudeReq, err := anthropic.ConvertRequest(c, *request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "convert request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := Request{
 | 
						req := Request{
 | 
				
			||||||
		AnthropicVersion: anthropicVersion,
 | 
							AnthropicVersion: anthropicVersion,
 | 
				
			||||||
		// Model:            claudeReq.Model,
 | 
							// Model:            claudeReq.Model,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user