diff --git a/README.md b/README.md index b561cda..b35edb7 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ 8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md) 10. Dify +11. Vertex AI,目前兼容Claude,Gemini,Llama3.1 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 diff --git a/common/constants.go b/common/constants.go index d63955e..51144fd 100644 --- a/common/constants.go +++ b/common/constants.go @@ -214,6 +214,7 @@ const ( ChannelTypeJina = 38 ChannelCloudflare = 39 ChannelTypeSiliconFlow = 40 + ChannelTypeVertexAi = 41 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -261,4 +262,5 @@ var ChannelBaseURLs = []string{ "https://api.jina.ai", //38 "https://api.cloudflare.com", //39 "https://api.siliconflow.cn", //40 + "", //41 } diff --git a/controller/channel.go b/controller/channel.go index d723ee6..65ef721 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -198,6 +198,9 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") + if channel.Type == common.ChannelTypeVertexAi { + keys = []string{channel.Key} + } channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { diff --git a/go.mod b/go.mod index f97217b..4277a88 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,9 @@ module one-api // +heroku goVersion go1.18 -go 1.18 +go 1.21 + +toolchain go1.22.4 require ( github.com/Calcium-Ion/go-epay v0.0.2 @@ -9,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 + github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -24,7 +27,7 @@ require ( github.com/pkoukk/tiktoken-go v0.1.7 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible - golang.org/x/crypto v0.21.0 + golang.org/x/crypto v0.26.0 golang.org/x/image v0.15.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 @@ -38,9 +41,8 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect - github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect github.com/bytedance/sonic v1.9.1 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -51,6 +53,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect @@ -69,6 +72,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -76,10 +80,10 @@ require ( github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f19b88c..fedd195 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= @@ -37,6 +37,7 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= @@ -57,6 +58,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -81,7 +83,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -142,8 +145,11 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= @@ -172,7 +178,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= @@ -191,18 +198,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -214,26 +221,27 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/middleware/distributor.go b/middleware/distributor.go index 1be3b31..3ca5b8f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -199,6 +199,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeVertexAi: + c.Set("region", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) case common.ChannelTypeGemini: diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 0544695..b9173af 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = claudeStreamHandler(c, resp, info, a.RequestMode) + err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = ClaudeHandler(c, resp, a.RequestMode, info) } return } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1c9d4e6..1923e35 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -346,7 +346,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope return &fullTextResponse } -func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) var usage *dto.Usage usage = &dto.Usage{} @@ -428,7 +428,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return nil, usage } -func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -454,15 +454,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, model) + completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } usage := dto.Usage{} if requestMode == RequestModeCompletion { - usage.PromptTokens = promptTokens + usage.PromptTokens = info.PromptTokens usage.CompletionTokens = completionTokens - usage.TotalTokens = promptTokens + completionTokens + usage.TotalTokens = info.PromptTokens + completionTokens } else { usage.PromptTokens = claudeResponse.Usage.InputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4c4649f..07cdcfa 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -70,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = geminiChatStreamHandler(c, resp, info) + err, usage = GeminiChatStreamHandler(c, resp, info) } else { - err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = GeminiChatHandler(c, resp) } return } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 185582d..f6dba5e 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseText := "" id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() @@ -279,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom return nil, usage } -func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go new file mode 100644 index 0000000..e3b4782 --- /dev/null +++ b/relay/channel/vertex/adaptor.go @@ -0,0 +1,183 @@ +package vertex + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/claude" + "one-api/relay/channel/gemini" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" + "strings" +) + +const ( + RequestModeClaude = 1 + RequestModeGemini = 2 + RequestModeLlama = 3 +) + +var claudeModelMap = map[string]string{ + "claude-3-sonnet-20240229": "claude-3-sonnet@20240229", + "claude-3-opus-20240229": "claude-3-opus@20240229", + "claude-3-haiku-20240307": "claude-3-haiku@20240307", + "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", +} + +const anthropicVersion = "vertex-2023-10-16" + +type Adaptor struct { + RequestMode int + AccountCredentials Credentials +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + if strings.HasPrefix(info.UpstreamModelName, "claude") { + a.RequestMode = RequestModeClaude + } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { + a.RequestMode = RequestModeGemini + } else if strings.Contains(info.UpstreamModelName, "llama") { + a.RequestMode = RequestModeLlama + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + adc := &Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials file: %w", err) + } + a.AccountCredentials = *adc + suffix := "" + if a.RequestMode == RequestModeGemini { + if info.IsStream { + suffix = "streamGenerateContent?alt=sse" + } else { + suffix = "generateContent" + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + info.ApiVersion, + adc.ProjectID, + info.ApiVersion, + info.UpstreamModelName, + suffix, + ), nil + } else if a.RequestMode == RequestModeClaude { + if info.IsStream { + suffix = "streamRawPredict?alt=sse" + } else { + suffix = "rawPredict" + } + if v, ok := claudeModelMap[info.UpstreamModelName]; ok { + info.UpstreamModelName = v + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", + info.ApiVersion, + adc.ProjectID, + info.ApiVersion, + info.UpstreamModelName, + suffix, + ), nil + } else if a.RequestMode == RequestModeLlama { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", + info.ApiVersion, + adc.ProjectID, + info.ApiVersion, + ), nil + } + return "", errors.New("unsupported request mode") +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + accessToken, err := getAccessToken(a, info) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + if a.RequestMode == RequestModeClaude { + claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request) + if err != nil { + return nil, err + } + vertexClaudeReq := &VertexAIClaudeRequest{ + AnthropicVersion: anthropicVersion, + } + if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil { + return nil, errors.New("failed to copy claude request") + } + c.Set("request_model", request.Model) + return vertexClaudeReq, nil + } else if a.RequestMode == RequestModeGemini { + geminiRequest := gemini.CovertGemini2OpenAI(*request) + c.Set("request_model", request.Model) + return geminiRequest, nil + } else if a.RequestMode == RequestModeLlama { + return request, nil + } + return nil, errors.New("unsupported request mode") +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + switch a.RequestMode { + case RequestModeClaude: + err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + case RequestModeGemini: + err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + case RequestModeLlama: + err, usage = openai.OaiStreamHandler(c, resp, info) + } + } else { + switch a.RequestMode { + case RequestModeClaude: + err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) + case RequestModeGemini: + err, usage = gemini.GeminiChatHandler(c, resp) + case RequestModeLlama: + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/vertex/constants.go b/relay/channel/vertex/constants.go new file mode 100644 index 0000000..6a31a86 --- /dev/null +++ b/relay/channel/vertex/constants.go @@ -0,0 +1,15 @@ +package vertex + +var ModelList = []string{ + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + "claude-3-5-sonnet-20240620", + + //"gemini-1.5-pro-latest", "gemini-1.5-flash-latest", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", + + "meta/llama3-405b-instruct-maas", +} + +var ChannelName = "vertex-ai" diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go new file mode 100644 index 0000000..b54a4aa --- /dev/null +++ b/relay/channel/vertex/dto.go @@ -0,0 +1,17 @@ +package vertex + +import "one-api/relay/channel/claude" + +type VertexAIClaudeRequest struct { + AnthropicVersion string `json:"anthropic_version"` + Messages []claude.ClaudeMessage `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []claude.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go new file mode 100644 index 0000000..884d09a --- /dev/null +++ b/relay/channel/vertex/service_account.go @@ -0,0 +1,122 @@ +package vertex + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "github.com/bytedance/gopkg/cache/asynccache" + "github.com/golang-jwt/jwt" + "net/http" + "net/url" + relaycommon "one-api/relay/common" + "strings" + + "fmt" + "time" +) + +type Credentials struct { + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + ClientID string `json:"client_id"` +} + +var Cache = asynccache.NewAsyncCache(asynccache.Options{ + RefreshDuration: time.Minute * 35, + EnableExpire: true, + ExpireDuration: time.Minute * 30, + Fetcher: func(key string) (interface{}, error) { + return nil, errors.New("not found") + }, +}) + +func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { + cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId) + val, err := Cache.Get(cacheKey) + if err == nil { + return val.(string), nil + } + + signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey) + if err != nil { + return "", fmt.Errorf("failed to create signed JWT: %w", err) + } + newToken, err := exchangeJwtForAccessToken(signedJWT) + if err != nil { + return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) + } + if err := Cache.SetDefault(cacheKey, newToken); err { + return newToken, nil + } + return newToken, nil +} + +func createSignedJWT(email, privateKeyPEM string) (string, error) { + + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "") + + block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----")) + if block == nil { + return "", fmt.Errorf("failed to parse PEM block containing the private key") + } + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return "", err + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return "", fmt.Errorf("not an RSA private key") + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": email, + "scope": "https://www.googleapis.com/auth/cloud-platform", + "aud": "https://www.googleapis.com/oauth2/v4/token", + "exp": now.Add(time.Minute * 30).Unix(), + "iat": now.Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedToken, err := token.SignedString(rsaPrivateKey) + if err != nil { + return "", err + } + + return signedToken, nil +} + +func exchangeJwtForAccessToken(signedJWT string) (string, error) { + + authURL := "https://www.googleapis.com/oauth2/v4/token" + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + data.Set("assertion", signedJWT) + + resp, err := http.PostForm(authURL, data) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + if accessToken, ok := result["access_token"].(string); ok { + return accessToken, nil + } + + return "", fmt.Errorf("failed to get access token: %v", result) +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3ed5ee3..db326aa 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -22,6 +22,7 @@ type RelayInfo struct { IsStream bool RelayMode int UpstreamModelName string + OriginModelName string RequestURLPath string ApiVersion string PromptTokens int @@ -57,6 +58,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { TokenUnlimited: tokenUnlimited, StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), + OriginModelName: c.GetString("original_model"), + UpstreamModelName: c.GetString("original_model"), ApiType: apiType, ApiVersion: c.GetString("api_version"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), @@ -68,6 +71,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.ChannelType == common.ChannelTypeAzure { info.ApiVersion = GetAPIVersion(c) } + if info.ChannelType == common.ChannelTypeVertexAi { + info.ApiVersion = c.GetString("region") + } if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelCloudflare { diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 36a6b6b..98c6cc0 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -24,6 +24,7 @@ const ( APITypeJina APITypeCloudflare APITypeSiliconFlow + APITypeVertexAi APITypeDummy // this one is only for count, do not add any channel after this ) @@ -69,6 +70,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeCloudflare case common.ChannelTypeSiliconFlow: apiType = APITypeSiliconFlow + case common.ChannelTypeVertexAi: + apiType = APITypeVertexAi } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 96fb737..0164782 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -19,6 +19,7 @@ import ( "one-api/relay/channel/siliconflow" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" + "one-api/relay/channel/vertex" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" @@ -65,6 +66,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &cloudflare.Adaptor{} case constant.APITypeSiliconFlow: return &siliconflow.Adaptor{} + case constant.APITypeVertexAi: + return &vertex.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index ee53f7b..8a3ddbd 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -36,6 +36,7 @@ export const CHANNEL_OPTIONS = [ color: 'indigo', label: 'AWS Claude' }, + { key: 41, text: 'Vertex AI', value: 41, color: 'blue', label: 'Vertex AI' }, { key: 3, text: 'Azure OpenAI', diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 900fdf3..d9732e9 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -588,6 +588,24 @@ const EditChannel = (props) => { /> )} + {inputs.type === 41 && ( + <> +
+ 部署地区: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )} {inputs.type === 21 && ( <>
@@ -734,17 +752,47 @@ const EditChannel = (props) => { autoComplete='new-password' /> ) : ( - { - handleInputChange('key', value); - }} - value={inputs.key} - autoComplete='new-password' - /> + <> + {inputs.type === 41 ? ( +