diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 5bf3551..9f4221a 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -7,9 +7,9 @@ import ( type AwsClaudeRequest struct { // AnthropicVersion should be "bedrock-2023-05-31" AnthropicVersion string `json:"anthropic_version"` - System string `json:"system"` + System string `json:"system,omitempty"` Messages []claude.ClaudeMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` @@ -17,3 +17,18 @@ type AwsClaudeRequest struct { Tools []claude.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` } + +func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { + return &AwsClaudeRequest{ + AnthropicVersion: "bedrock-2023-05-31", + System: req.System, + Messages: req.Messages, + MaxTokens: req.MaxTokens, + Temperature: req.Temperature, + TopP: req.TopP, + TopK: req.TopK, + StopSequences: req.StopSequences, + Tools: req.Tools, + ToolChoice: req.ToolChoice, + } +} diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 748a84e..1b0882b 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "github.com/gin-gonic/gin" - "github.com/jinzhu/copier" "github.com/pkg/errors" "io" "net/http" @@ -78,13 +77,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return wrapErr(errors.New("request not found")), nil } claudeReq := claudeReq_.(*claude.ClaudeRequest) - awsClaudeReq := &AwsClaudeRequest{ - AnthropicVersion: "bedrock-2023-05-31", - } - if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil - } - + awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { return wrapErr(errors.Wrap(err, "marshal request")), nil @@ -136,12 +129,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } claudeReq := claudeReq_.(*claude.ClaudeRequest) - awsClaudeReq := &AwsClaudeRequest{ - AnthropicVersion: "bedrock-2023-05-31", - } - if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil - } + awsClaudeReq := copyRequest(claudeReq) awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { return wrapErr(errors.Wrap(err, "marshal request")), nil