diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index da000b58..c20827b0 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -104,13 +104,14 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { createdTime := helper.GetTimestamp() - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go index aff3e0cf..76b06f91 100644 --- a/relay/adaptor/aws/llama3/main.go +++ b/relay/adaptor/aws/llama3/main.go @@ -141,13 +141,14 @@ func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { createdTime := helper.GetTimestamp() - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), }