mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-05 08:13:43 +08:00
feat: add support for aws's cross region inferences
closes #2024, closes #2145
This commit is contained in:
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(c, *request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
|||||||
@@ -6,16 +6,17 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/render"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/image"
|
"github.com/songquanpeng/one-api/common/image"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/render"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
@@ -38,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
// isModelSupportThinking is used to check if the model supports extended thinking
|
||||||
|
func isModelSupportThinking(model string) bool {
|
||||||
|
if strings.Contains(model, "claude-3-7-sonnet") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
|
||||||
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||||
|
|
||||||
for _, tool := range textRequest.Tools {
|
for _, tool := range textRequest.Tools {
|
||||||
@@ -65,6 +75,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
|||||||
Tools: claudeTools,
|
Tools: claudeTools,
|
||||||
Thinking: textRequest.Thinking,
|
Thinking: textRequest.Thinking,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isModelSupportThinking(textRequest.Model) &&
|
||||||
|
c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
|
||||||
|
claudeRequest.Thinking = &model.Thinking{
|
||||||
|
Type: "enabled",
|
||||||
|
BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isModelSupportThinking(textRequest.Model) &&
|
||||||
|
claudeRequest.Thinking != nil {
|
||||||
|
if claudeRequest.MaxTokens <= 1024 {
|
||||||
|
return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking")
|
||||||
|
}
|
||||||
|
|
||||||
|
// top_p must be nil when using extended thinking
|
||||||
|
claudeRequest.TopP = nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(claudeTools) > 0 {
|
if len(claudeTools) > 0 {
|
||||||
claudeToolChoice := struct {
|
claudeToolChoice := struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -145,7 +174,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
|||||||
claudeMessage.Content = contents
|
claudeMessage.Content = contents
|
||||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||||
}
|
}
|
||||||
return &claudeRequest
|
return &claudeRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.anthropic.com/claude/reference/messages-streaming
|
// https://docs.anthropic.com/claude/reference/messages-streaming
|
||||||
|
|||||||
@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeReq := anthropic.ConvertRequest(*request)
|
claudeReq, err := anthropic.ConvertRequest(c, *request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "convert request")
|
||||||
|
}
|
||||||
|
|
||||||
c.Set(ctxkey.RequestModel, request.Model)
|
c.Set(ctxkey.RequestModel, request.Model)
|
||||||
c.Set(ctxkey.ConvertedRequest, claudeReq)
|
c.Set(ctxkey.ConvertedRequest, claudeReq)
|
||||||
return claudeReq, nil
|
return claudeReq, nil
|
||||||
|
|||||||
@@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
|
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
|
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
|
||||||
awsReq := &bedrockruntime.InvokeModelInput{
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
ModelId: aws.String(awsModelId),
|
ModelId: aws.String(awsModelID),
|
||||||
Accept: aws.String("application/json"),
|
Accept: aws.String("application/json"),
|
||||||
ContentType: aws.String("application/json"),
|
ContentType: aws.String("application/json"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
|
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
|
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
|
||||||
awsReq := &bedrockruntime.InvokeModelInput{
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
ModelId: aws.String(awsModelId),
|
ModelId: aws.String(awsModelID),
|
||||||
Accept: aws.String("application/json"),
|
Accept: aws.String("application/json"),
|
||||||
ContentType: aws.String("application/json"),
|
ContentType: aws.String("application/json"),
|
||||||
}
|
}
|
||||||
|
|||||||
75
relay/adaptor/aws/utils/consts.go
Normal file
75
relay/adaptor/aws/utils/consts.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CrossRegionInferences is a list of model IDs that support cross-region inference.
|
||||||
|
//
|
||||||
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||||
|
//
|
||||||
|
// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
|
||||||
|
var CrossRegionInferences = []string{
|
||||||
|
"us.amazon.nova-lite-v1:0",
|
||||||
|
"us.amazon.nova-micro-v1:0",
|
||||||
|
"us.amazon.nova-pro-v1:0",
|
||||||
|
"us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||||
|
"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||||
|
"us.anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"us.anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
"us.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"us.meta.llama3-1-405b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-1-70b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-1-8b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-2-11b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-2-1b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-2-3b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-2-90b-instruct-v1:0",
|
||||||
|
"us.meta.llama3-3-70b-instruct-v1:0",
|
||||||
|
"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"eu.amazon.nova-lite-v1:0",
|
||||||
|
"eu.amazon.nova-micro-v1:0",
|
||||||
|
"eu.amazon.nova-pro-v1:0",
|
||||||
|
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"eu.anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"eu.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"eu.meta.llama3-2-1b-instruct-v1:0",
|
||||||
|
"eu.meta.llama3-2-3b-instruct-v1:0",
|
||||||
|
"apac.amazon.nova-lite-v1:0",
|
||||||
|
"apac.amazon.nova-micro-v1:0",
|
||||||
|
"apac.amazon.nova-pro-v1:0",
|
||||||
|
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
"apac.anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"apac.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID.
|
||||||
|
func ConvertModelID2CrossRegionProfile(model, region string) string {
|
||||||
|
var regionPrefix string
|
||||||
|
switch prefix := strings.Split(region, "-")[0]; prefix {
|
||||||
|
case "us", "eu":
|
||||||
|
regionPrefix = prefix
|
||||||
|
case "ap":
|
||||||
|
regionPrefix = "apac"
|
||||||
|
default:
|
||||||
|
// not supported, return original model
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
|
newModelID := regionPrefix + "." + model
|
||||||
|
if slices.Contains(CrossRegionInferences, newModelID) {
|
||||||
|
logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
|
||||||
|
return newModelID
|
||||||
|
}
|
||||||
|
|
||||||
|
// not found, return original model
|
||||||
|
return model
|
||||||
|
}
|
||||||
@@ -32,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeReq := anthropic.ConvertRequest(*request)
|
claudeReq, err := anthropic.ConvertRequest(c, *request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "convert request")
|
||||||
|
}
|
||||||
|
|
||||||
req := Request{
|
req := Request{
|
||||||
AnthropicVersion: anthropicVersion,
|
AnthropicVersion: anthropicVersion,
|
||||||
// Model: claudeReq.Model,
|
// Model: claudeReq.Model,
|
||||||
|
|||||||
Reference in New Issue
Block a user