From 9294127686633dbb59f54d7d753490c5c9f3dd78 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Tue, 23 Apr 2024 11:44:40 +0800 Subject: [PATCH] feat: support aws claude --- common/constants.go | 4 + controller/channel-test.go | 2 +- go.mod | 9 ++ go.sum | 31 ++-- model/channel.go | 7 +- relay/channel/aws/adaptor.go | 79 +++++++++ relay/channel/aws/constants.go | 12 ++ relay/channel/aws/dto.go | 14 ++ relay/channel/aws/relay-aws.go | 211 +++++++++++++++++++++++++ relay/channel/claude/adaptor.go | 4 +- relay/channel/claude/relay-claude.go | 17 +- relay/constant/api_type.go | 3 + relay/relay-text.go | 16 +- relay/relay_adaptor.go | 3 + web/src/constants/channel.constants.js | 7 + web/src/helpers/render.js | 19 ++- web/src/pages/Channel/EditChannel.js | 107 ++++++++----- 17 files changed, 464 insertions(+), 81 deletions(-) create mode 100644 relay/channel/aws/adaptor.go create mode 100644 relay/channel/aws/constants.go create mode 100644 relay/channel/aws/dto.go create mode 100644 relay/channel/aws/relay-aws.go diff --git a/common/constants.go b/common/constants.go index 78935ec..f0fb1d5 100644 --- a/common/constants.go +++ b/common/constants.go @@ -206,6 +206,7 @@ const ( ChannelTypeZhipu_v4 = 26 ChannelTypePerplexity = 27 ChannelTypeLingYiWanWu = 31 + ChannelTypeAws = 33 ) var ChannelBaseURLs = []string{ @@ -241,4 +242,7 @@ var ChannelBaseURLs = []string{ "", //29 "", //30 "https://api.lingyiwanwu.com", //31 + "", //32 + "", //33 + } diff --git a/controller/channel-test.go b/controller/channel-test.go index e407193..f66e0d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -86,7 +86,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if err != nil { return err, nil } - if resp.StatusCode != http.StatusOK { + if resp != nil && resp.StatusCode != http.StatusOK { err := relaycommon.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } diff --git a/go.mod b/go.mod index 62bc80e..a7dc491 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,9 @@ go 1.18 require ( github.com/Calcium-Ion/go-epay v0.0.2 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -16,6 +19,8 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 + github.com/jinzhu/copier v0.4.0 + github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.6 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible @@ -29,6 +34,10 @@ require ( require ( github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index 5bb3189..c6e80ba 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,23 @@ -github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc= -github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc= github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -66,8 +78,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -87,6 +99,8 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -132,6 +146,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -139,14 +155,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= -github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI= -github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -179,8 +191,6 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= @@ -188,8 +198,6 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -212,7 +220,6 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= diff --git a/model/channel.go b/model/channel.go index c0c21c0..5b35851 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,9 +25,10 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(64);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` - Priority *int64 `json:"priority" gorm:"bigint;default:0"` - AutoBan *int `json:"auto_ban" gorm:"default:1"` + //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` + StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` + AutoBan *int `json:"auto_ban" gorm:"default:1"` } func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go new file mode 100644 index 0000000..23c69db --- /dev/null +++ b/relay/channel/aws/adaptor.go @@ -0,0 +1,79 @@ +package aws + +import ( + "errors" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel/claude" + relaycommon "one-api/relay/common" + "strings" +) + +const ( + RequestModeCompletion = 1 + RequestModeMessage = 2 +) + +type Adaptor struct { + RequestMode int +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + if strings.HasPrefix(info.UpstreamModelName, "claude-3") { + a.RequestMode = RequestModeMessage + } else { + a.RequestMode = RequestModeCompletion + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + var claudeReq *claude.ClaudeRequest + var err error + if a.RequestMode == RequestModeCompletion { + claudeReq = claude.RequestOpenAI2ClaudeComplete(*request) + } else { + claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) + } + c.Set("request_model", request.Model) + c.Set("converted_request", claudeReq) + return claudeReq, err +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = awsStreamHandler(c, info, a.RequestMode) + } else { + err, usage = awsHandler(c, info, a.RequestMode) + } + return +} + +func (a *Adaptor) GetModelList() (models []string) { + for n := range awsModelIDMap { + models = append(models, n) + } + + return +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go new file mode 100644 index 0000000..0b03785 --- /dev/null +++ b/relay/channel/aws/constants.go @@ -0,0 +1,12 @@ +package aws + +var awsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +} + +var ChannelName = "aws" diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go new file mode 100644 index 0000000..7450908 --- /dev/null +++ b/relay/channel/aws/dto.go @@ -0,0 +1,14 @@ +package aws + +import "one-api/relay/channel/claude" + +type AwsClaudeRequest struct { + // AnthropicVersion should be "bedrock-2023-05-31" + AnthropicVersion string `json:"anthropic_version"` + Messages []claude.ClaudeMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go new file mode 100644 index 0000000..bf64f03 --- /dev/null +++ b/relay/channel/aws/relay-aws.go @@ -0,0 +1,211 @@ +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/pkg/errors" + "io" + "net/http" + "one-api/common" + relaymodel "one-api/dto" + "one-api/relay/channel/claude" + relaycommon "one-api/relay/common" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" +) + +func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) { + awsSecret := strings.Split(info.ApiKey, "|") + if len(awsSecret) != 3 { + return nil, errors.New("invalid aws secret key") + } + ak := awsSecret[0] + sk := awsSecret[1] + region := awsSecret[2] + client := bedrockruntime.New(bedrockruntime.Options{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), + }) + + return client, nil +} + +func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode { + return &relaymodel.OpenAIErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.OpenAIError{ + Message: fmt.Sprintf("%s", err.Error()), + }, + } +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := awsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { + awsCli, err := newAwsClient(c, info) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString("request_model")) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get("converted_request") + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*claude.ClaudeRequest) + awsClaudeReq := &AwsClaudeRequest{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + claudeResponse := new(claude.ClaudeResponse) + err = json.Unmarshal(awsResp.Body, claudeResponse) + if err != nil { + return wrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse) + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { + awsCli, err := newAwsClient(c, info) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString("request_model")) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get("converted_request") + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*claude.ClaudeRequest) + + awsClaudeReq := &AwsClaudeRequest{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + var id string + var model string + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + claudeResp := new(claude.ClaudeResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp) + if claudeUsage != nil { + usage.PromptTokens += claudeUsage.InputTokens + usage.CompletionTokens += claudeUsage.OutputTokens + } + + if response == nil { + return true + } + + if response.Id != "" { + id = response.Id + } + if response.Model != "" { + model = response.Model + } + response.Id = id + response.Model = model + + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 45efd01..9add208 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,9 +53,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return nil, errors.New("request is nil") } if a.RequestMode == RequestModeCompletion { - return requestOpenAI2ClaudeComplete(*request), nil + return RequestOpenAI2ClaudeComplete(*request), nil } else { - return requestOpenAI2ClaudeMessage(*request) + return RequestOpenAI2ClaudeMessage(*request) } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 2b5d3d2..3d99664 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -26,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { +func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -57,7 +57,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR return &claudeRequest } -func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { +func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { claudeRequest := ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -122,7 +122,7 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { var response dto.ChatCompletionsStreamResponse var claudeUsage *ClaudeUsage response.Object = "chat.completion.chunk" @@ -149,6 +149,8 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* choice.FinishReason = &finishReason } claudeUsage = &claudeResponse.Usage + } else if claudeResponse.Type == "message_stop" { + return nil, nil } } if claudeUsage == nil { @@ -158,7 +160,7 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* return &response, claudeUsage } -func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { +func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { choices := make([]dto.OpenAITextResponseChoice, 0) fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -242,7 +244,10 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c return true } - response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse) + response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + if response == nil { + return true + } if requestMode == RequestModeCompletion { responseText += claudeResponse.Completion responseId = response.Id @@ -317,7 +322,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT StatusCode: resp.StatusCode, }, nil } - fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) + fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 8e6f67e..8ee6a99 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -18,6 +18,7 @@ const ( APITypeZhipu_v4 APITypeOllama APITypePerplexity + APITypeAws APITypeDummy // this one is only for count, do not add any channel after this ) @@ -49,6 +50,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeOllama case common.ChannelTypePerplexity: apiType = APITypePerplexity + case common.ChannelTypeAws: + apiType = APITypeAws } return apiType } diff --git a/relay/relay-text.go b/relay/relay-text.go index 6026560..890f543 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -159,14 +159,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(resp) - // reset status code 重置状态码 - service.ResetStatusCode(openaiErr, statusCodeMappingStr) - return openaiErr + if resp != nil { + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } } usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index e6afab5..867fd53 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -3,6 +3,7 @@ package relay import ( "one-api/relay/channel" "one-api/relay/channel/ali" + "one-api/relay/channel/aws" "one-api/relay/channel/baidu" "one-api/relay/channel/claude" "one-api/relay/channel/gemini" @@ -45,6 +46,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &ollama.Adaptor{} case constant.APITypePerplexity: return &perplexity.Adaptor{} + case constant.APITypeAws: + return &aws.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 8fdfd1b..1322131 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -22,6 +22,13 @@ export const CHANNEL_OPTIONS = [ color: 'indigo', label: 'Anthropic Claude', }, + { + key: 33, + text: 'AWS Claude', + value: 33, + color: 'black', + label: 'AWS Claude', + }, { key: 3, text: 'Azure OpenAI', diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index b76b6c8..8cea432 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -164,24 +164,23 @@ const colors = [ export const modelColorMap = { 'dall-e': 'rgb(147,112,219)', // 深紫色 - 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调 + // 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调 'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调 - midjourney: 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调 'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色 - 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色 + // 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色 'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿 'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿 - 'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色 - 'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色 + 'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色 + 'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃色 'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色 'gpt-4': 'rgb(135,206,235)', // 天蓝色 - 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色 + // 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色 'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝 'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝 'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝 'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝 'gpt-4-32k': 'rgb(104,111,238)', // 中紫色 - 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色 + // 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色 'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色 'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝 'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色 @@ -189,7 +188,7 @@ export const modelColorMap = { 'text-ada-001': 'rgb(255,192,203)', // 粉红色 'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色 'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色 - 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色 + // 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色 'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色(与Curie相同,表示同一个系列) 'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色 'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红 @@ -201,6 +200,10 @@ export const modelColorMap = { 'tts-1-hd': 'rgb(255,215,0)', // 金色 'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别) 'whisper-1': 'rgb(245,245,220)', // 米色 + 'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色 + 'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色 + 'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色 + 'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别) }; export function stringToColor(str) { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index f93e4dd..8da2e30 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -22,6 +22,7 @@ import { Checkbox, Banner, } from '@douyinfe/semi-ui'; +import { Divider } from 'semantic-ui-react'; const MODEL_MAPPING_EXAMPLE = { 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', @@ -44,6 +45,8 @@ function type2secretPrompt(type) { return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; case 23: return '按照如下格式输入:AppId|SecretId|SecretKey'; + case 33: + return '按照如下格式输入:Ak|Sk|Region'; default: return '请输入渠道对应的鉴权密钥'; } @@ -62,6 +65,7 @@ const EditChannel = (props) => { type: 1, key: '', openai_organization: '', + max_input_tokens: 0, base_url: '', other: '', model_mapping: '', @@ -86,6 +90,7 @@ const EditChannel = (props) => { if (name === 'type' && inputs.models.length === 0) { let localModels = []; switch (value) { + case 33: case 14: localModels = [ 'claude-instant-1.2', @@ -641,36 +646,6 @@ const EditChannel = (props) => { > 填入模板 -