diff --git a/README.md b/README.md index 4a7f4c6..206b71b 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ 8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md) 10. Dify +11. Vertex AI,目前兼容Claude,Gemini,Llama3.1 您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 diff --git a/common/constants.go b/common/constants.go index 930a3b0..6c2bb4b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -237,6 +237,7 @@ const ( ChannelTypeJina = 38 ChannelCloudflare = 39 ChannelTypeSiliconFlow = 40 + ChannelTypeVertexAi = 41 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -284,4 +285,5 @@ var ChannelBaseURLs = []string{ "https://api.jina.ai", //38 "https://api.cloudflare.com", //39 "https://api.siliconflow.cn", //40 + "", //41 } diff --git a/common/email.go b/common/email.go index 905f385..72e8979 100644 --- a/common/email.go +++ b/common/email.go @@ -10,7 +10,7 @@ import ( ) func generateMessageID() string { - domain := strings.Split(SMTPFrom, "@")[1] + domain := strings.Split(SMTPAccount, "@")[1] return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain) } @@ -18,6 +18,9 @@ func SendEmail(subject string, receiver string, content string) error { if SMTPFrom == "" { // for compatibility SMTPFrom = SMTPAccount } + if SMTPServer == "" && SMTPAccount == "" { + return fmt.Errorf("SMTP 服务器未配置") + } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) mail := []byte(fmt.Sprintf("To: %s\r\n"+ "From: %s<%s>\r\n"+ diff --git a/common/model-ratio.go b/common/model-ratio.go index 507c69d..4f7275a 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -117,6 +117,13 @@ var defaultModelRatio = map[string]float64{ "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens "glm-3-turbo": 0.3572, + "glm-4-plus": 0.05 * RMB, + "glm-4-0520": 0.1 * RMB, + "glm-4-air": 0.001 * RMB, + "glm-4-airx": 0.01 * RMB, + "glm-4-long": 0.001 * RMB, + "glm-4-flash": 0, + "glm-4v-plus": 0.01 * RMB, "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens @@ -135,26 +142,28 @@ var defaultModelRatio = map[string]float64{ "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 // https://platform.lingyiwanwu.com/docs#-计费单元 // 已经按照 7.2 来换算美元价格 - "yi-34b-chat-0205": 0.18, - "yi-34b-chat-200k": 0.864, - "yi-vl-plus": 0.432, - "yi-large": 20.0 / 1000 * RMB, - "yi-medium": 2.5 / 1000 * RMB, - "yi-vision": 6.0 / 1000 * RMB, - "yi-medium-200k": 12.0 / 1000 * RMB, - "yi-spark": 1.0 / 1000 * RMB, - "yi-large-rag": 25.0 / 1000 * RMB, - "yi-large-turbo": 12.0 / 1000 * RMB, - "yi-large-preview": 20.0 / 1000 * RMB, - "yi-large-rag-preview": 25.0 / 1000 * RMB, - "command": 0.5, - "command-nightly": 0.5, - "command-light": 0.5, - "command-light-nightly": 0.5, - "command-r": 0.25, - "command-r-plus ": 1.5, - "deepseek-chat": 0.07, - "deepseek-coder": 0.07, + "yi-34b-chat-0205": 0.18, + "yi-34b-chat-200k": 0.864, + "yi-vl-plus": 0.432, + "yi-large": 20.0 / 1000 * RMB, + "yi-medium": 2.5 / 1000 * RMB, + "yi-vision": 6.0 / 1000 * RMB, + "yi-medium-200k": 12.0 / 1000 * RMB, + "yi-spark": 1.0 / 1000 * RMB, + "yi-large-rag": 25.0 / 1000 * RMB, + "yi-large-turbo": 12.0 / 1000 * RMB, + "yi-large-preview": 20.0 / 1000 * RMB, + "yi-large-rag-preview": 25.0 / 1000 * RMB, + "command": 0.5, + "command-nightly": 0.5, + "command-light": 0.5, + "command-light-nightly": 0.5, + "command-r": 0.25, + "command-r-plus": 1.5, + "command-r-08-2024": 0.075, + "command-r-plus-08-2024": 1.25, + "deepseek-chat": 0.07, + "deepseek-coder": 0.07, // Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用 "llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-online": 0.2 / 1000 * USD, @@ -366,6 +375,10 @@ func GetCompletionRatio(name string) float64 { return 3 case "command-r-plus": return 5 + case "command-r-08-2024": + return 4 + case "command-r-plus-08-2024": + return 4 default: return 2 } diff --git a/common/str.go b/common/str.go index d61adb1..d42fd83 100644 --- a/common/str.go +++ b/common/str.go @@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string { return string(bytes) } -func MapToJsonStrFloat(m map[string]float64) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - func StrToMap(str string) map[string]interface{} { m := make(map[string]interface{}) err := json.Unmarshal([]byte(str), &m) @@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} { return m } +func IsJsonStr(str string) bool { + var js map[string]interface{} + return json.Unmarshal([]byte(str), &js) == nil +} + func String2Int(str string) int { num, err := strconv.Atoi(str) if err != nil { diff --git a/controller/channel.go b/controller/channel.go index d723ee6..7852595 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -198,6 +198,28 @@ func AddChannel(c *gin.Context) { } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") + if channel.Type == common.ChannelTypeVertexAi { + if channel.Other == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区不能为空", + }) + return + } else { + if common.IsJsonStr(channel.Other) { + // must have default + regionMap := common.StrToMap(channel.Other) + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return + } + } + } + keys = []string{channel.Key} + } channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { @@ -297,6 +319,27 @@ func UpdateChannel(c *gin.Context) { }) return } + if channel.Type == common.ChannelTypeVertexAi { + if channel.Other == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区不能为空", + }) + return + } else { + if common.IsJsonStr(channel.Other) { + // must have default + regionMap := common.StrToMap(channel.Other) + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return + } + } + } + } err = channel.Update() if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/go.mod b/go.mod index a88e235..a8ddec0 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,16 @@ module one-api // +heroku goVersion go1.18 -go 1.18 +go 1.21 + +toolchain go1.22.4 require ( github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 + github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -24,8 +27,9 @@ require ( github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible github.com/stripe/stripe-go/v76 v76.21.0 - golang.org/x/crypto v0.21.0 + golang.org/x/crypto v0.26.0 golang.org/x/image v0.15.0 + golang.org/x/net v0.28.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -38,9 +42,8 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect - github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect github.com/bytedance/sonic v1.9.1 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -51,6 +54,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect @@ -68,6 +72,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -75,10 +80,9 @@ require ( github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e4fad3c..90efe27 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= @@ -32,11 +32,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= -github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= @@ -57,6 +56,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -81,10 +81,9 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= @@ -128,8 +127,6 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8= -github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -144,16 +141,17 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= -github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -176,7 +174,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU= github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -197,19 +196,19 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -222,26 +221,27 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/middleware/distributor.go b/middleware/distributor.go index 1be3b31..3ca5b8f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -199,6 +199,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + case common.ChannelTypeVertexAi: + c.Set("region", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) case common.ChannelTypeGemini: diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 0544695..b9173af 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = claudeStreamHandler(c, resp, info, a.RequestMode) + err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = ClaudeHandler(c, resp, a.RequestMode, info) } return } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1a64938..529ee42 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -346,7 +346,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope return &fullTextResponse } -func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) var usage *dto.Usage usage = &dto.Usage{} @@ -428,7 +428,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return nil, usage } -func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -454,15 +454,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, model) + completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } usage := dto.Usage{} if requestMode == RequestModeCompletion { - usage.PromptTokens = promptTokens + usage.PromptTokens = info.PromptTokens usage.CompletionTokens = completionTokens - usage.TotalTokens = promptTokens + completionTokens + usage.TotalTokens = info.PromptTokens + completionTokens } else { usage.PromptTokens = claudeResponse.Usage.InputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens diff --git a/relay/channel/cohere/constant.go b/relay/channel/cohere/constant.go index 8f34e4f..734620a 100644 --- a/relay/channel/cohere/constant.go +++ b/relay/channel/cohere/constant.go @@ -1,7 +1,10 @@ package cohere var ModelList = []string{ - "command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly", + "command-r", "command-r-plus", + "command-r-08-2024", "command-r-plus-08-2024", + "c4ai-aya-23-35b", "c4ai-aya-23-8b", + "command-light", "command-light-nightly", "command", "command-nightly", "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0", } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4c4649f..07cdcfa 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -70,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = geminiChatStreamHandler(c, resp, info) + err, usage = GeminiChatStreamHandler(c, resp, info) } else { - err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = GeminiChatHandler(c, resp) } return } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 79dc44b..764d20b 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseText := "" id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() @@ -279,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom return nil, usage } -func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 6a04d08..f296ed0 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil } return "", errors.New("invalid relay mode") } @@ -58,6 +58,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) + } else if info.RelayMode == constant.RelayModeEmbeddings { + err, usage = jinaEmbeddingHandler(c, resp) } return } diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go index 5fdd44f..6c339ae 100644 --- a/relay/channel/jina/relay-jina.go +++ b/relay/channel/jina/relay-jina.go @@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit _, err = c.Writer.Write(jsonResponse) return nil, &jinaResp.Usage } + +func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var jinaResp dto.OpenAIEmbeddingResponse + err = json.Unmarshal(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + jsonResponse, err := json.Marshal(jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &jinaResp.Usage +} diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 1614905..6906fca 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } @@ -59,14 +59,17 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.RelayMode == constant.RelayModeRerank { + switch info.RelayMode { + case constant.RelayModeRerank: err, usage = siliconflowRerankHandler(c, resp) - } else if info.RelayMode == constant.RelayModeChatCompletions { + case constant.RelayModeChatCompletions: if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } + case constant.RelayModeEmbeddings: + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go new file mode 100644 index 0000000..4174d78 --- /dev/null +++ b/relay/channel/vertex/adaptor.go @@ -0,0 +1,184 @@ +package vertex + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/claude" + "one-api/relay/channel/gemini" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" + "strings" +) + +const ( + RequestModeClaude = 1 + RequestModeGemini = 2 + RequestModeLlama = 3 +) + +var claudeModelMap = map[string]string{ + "claude-3-sonnet-20240229": "claude-3-sonnet@20240229", + "claude-3-opus-20240229": "claude-3-opus@20240229", + "claude-3-haiku-20240307": "claude-3-haiku@20240307", + "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", +} + +const anthropicVersion = "vertex-2023-10-16" + +type Adaptor struct { + RequestMode int + AccountCredentials Credentials +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + if strings.HasPrefix(info.UpstreamModelName, "claude") { + a.RequestMode = RequestModeClaude + } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { + a.RequestMode = RequestModeGemini + } else if strings.Contains(info.UpstreamModelName, "llama") { + a.RequestMode = RequestModeLlama + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + adc := &Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials file: %w", err) + } + region := GetModelRegion(info.ApiVersion, info.OriginModelName) + a.AccountCredentials = *adc + suffix := "" + if a.RequestMode == RequestModeGemini { + if info.IsStream { + suffix = "streamGenerateContent?alt=sse" + } else { + suffix = "generateContent" + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + region, + adc.ProjectID, + region, + info.UpstreamModelName, + suffix, + ), nil + } else if a.RequestMode == RequestModeClaude { + if info.IsStream { + suffix = "streamRawPredict?alt=sse" + } else { + suffix = "rawPredict" + } + if v, ok := claudeModelMap[info.UpstreamModelName]; ok { + info.UpstreamModelName = v + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", + region, + adc.ProjectID, + region, + info.UpstreamModelName, + suffix, + ), nil + } else if a.RequestMode == RequestModeLlama { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", + region, + adc.ProjectID, + region, + ), nil + } + return "", errors.New("unsupported request mode") +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + accessToken, err := getAccessToken(a, info) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + if a.RequestMode == RequestModeClaude { + claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request) + if err != nil { + return nil, err + } + vertexClaudeReq := &VertexAIClaudeRequest{ + AnthropicVersion: anthropicVersion, + } + if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil { + return nil, errors.New("failed to copy claude request") + } + c.Set("request_model", request.Model) + return vertexClaudeReq, nil + } else if a.RequestMode == RequestModeGemini { + geminiRequest := gemini.CovertGemini2OpenAI(*request) + c.Set("request_model", request.Model) + return geminiRequest, nil + } else if a.RequestMode == RequestModeLlama { + return request, nil + } + return nil, errors.New("unsupported request mode") +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + switch a.RequestMode { + case RequestModeClaude: + err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + case RequestModeGemini: + err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + case RequestModeLlama: + err, usage = openai.OaiStreamHandler(c, resp, info) + } + } else { + switch a.RequestMode { + case RequestModeClaude: + err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) + case RequestModeGemini: + err, usage = gemini.GeminiChatHandler(c, resp) + case RequestModeLlama: + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/vertex/constants.go b/relay/channel/vertex/constants.go new file mode 100644 index 0000000..6a31a86 --- /dev/null +++ b/relay/channel/vertex/constants.go @@ -0,0 +1,15 @@ +package vertex + +var ModelList = []string{ + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + "claude-3-5-sonnet-20240620", + + //"gemini-1.5-pro-latest", "gemini-1.5-flash-latest", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", + + "meta/llama3-405b-instruct-maas", +} + +var ChannelName = "vertex-ai" diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go new file mode 100644 index 0000000..b54a4aa --- /dev/null +++ b/relay/channel/vertex/dto.go @@ -0,0 +1,17 @@ +package vertex + +import "one-api/relay/channel/claude" + +type VertexAIClaudeRequest struct { + AnthropicVersion string `json:"anthropic_version"` + Messages []claude.ClaudeMessage `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []claude.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` +} diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go new file mode 100644 index 0000000..d259632 --- /dev/null +++ b/relay/channel/vertex/relay-vertex.go @@ -0,0 +1,16 @@ +package vertex + +import "one-api/common" + +func GetModelRegion(other string, localModelName string) string { + // if other is json string + if common.IsJsonStr(other) { + m := common.StrToMap(other) + if m[localModelName] != nil { + return m[localModelName].(string) + } else { + return m["default"].(string) + } + } + return other +} diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go new file mode 100644 index 0000000..cc64080 --- /dev/null +++ b/relay/channel/vertex/service_account.go @@ -0,0 +1,122 @@ +package vertex + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "github.com/bytedance/gopkg/cache/asynccache" + "github.com/golang-jwt/jwt" + "net/http" + "net/url" + relaycommon "one-api/relay/common" + "strings" + + "fmt" + "time" +) + +type Credentials struct { + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + ClientID string `json:"client_id"` +} + +var Cache = asynccache.NewAsyncCache(asynccache.Options{ + RefreshDuration: time.Minute * 35, + EnableExpire: true, + ExpireDuration: time.Minute * 30, + Fetcher: func(key string) (interface{}, error) { + return nil, errors.New("not found") + }, +}) + +func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { + cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId) + val, err := Cache.Get(cacheKey) + if err == nil { + return val.(string), nil + } + + signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey) + if err != nil { + return "", fmt.Errorf("failed to create signed JWT: %w", err) + } + newToken, err := exchangeJwtForAccessToken(signedJWT) + if err != nil { + return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) + } + if err := Cache.SetDefault(cacheKey, newToken); err { + return newToken, nil + } + return newToken, nil +} + +func createSignedJWT(email, privateKeyPEM string) (string, error) { + + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "") + privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "") + + block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----")) + if block == nil { + return "", fmt.Errorf("failed to parse PEM block containing the private key") + } + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return "", err + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return "", fmt.Errorf("not an RSA private key") + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": email, + "scope": "https://www.googleapis.com/auth/cloud-platform", + "aud": "https://www.googleapis.com/oauth2/v4/token", + "exp": now.Add(time.Minute * 35).Unix(), + "iat": now.Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedToken, err := token.SignedString(rsaPrivateKey) + if err != nil { + return "", err + } + + return signedToken, nil +} + +func exchangeJwtForAccessToken(signedJWT string) (string, error) { + + authURL := "https://www.googleapis.com/oauth2/v4/token" + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + data.Set("assertion", signedJWT) + + resp, err := http.PostForm(authURL, data) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + if accessToken, ok := result["access_token"].(string); ok { + return accessToken, nil + } + + return "", fmt.Errorf("failed to get access token: %v", result) +} diff --git a/relay/channel/zhipu_4v/constants.go b/relay/channel/zhipu_4v/constants.go index 3383eb3..816fa53 100644 --- a/relay/channel/zhipu_4v/constants.go +++ b/relay/channel/zhipu_4v/constants.go @@ -1,7 +1,7 @@ package zhipu_4v var ModelList = []string{ - "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", + "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus", } var ChannelName = "zhipu_4v" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3ed5ee3..db326aa 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -22,6 +22,7 @@ type RelayInfo struct { IsStream bool RelayMode int UpstreamModelName string + OriginModelName string RequestURLPath string ApiVersion string PromptTokens int @@ -57,6 +58,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { TokenUnlimited: tokenUnlimited, StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), + OriginModelName: c.GetString("original_model"), + UpstreamModelName: c.GetString("original_model"), ApiType: apiType, ApiVersion: c.GetString("api_version"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), @@ -68,6 +71,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.ChannelType == common.ChannelTypeAzure { info.ApiVersion = GetAPIVersion(c) } + if info.ChannelType == common.ChannelTypeVertexAi { + info.ApiVersion = c.GetString("region") + } if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelCloudflare { diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 36a6b6b..98c6cc0 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -24,6 +24,7 @@ const ( APITypeJina APITypeCloudflare APITypeSiliconFlow + APITypeVertexAi APITypeDummy // this one is only for count, do not add any channel after this ) @@ -69,6 +70,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeCloudflare case common.ChannelTypeSiliconFlow: apiType = APITypeSiliconFlow + case common.ChannelTypeVertexAi: + apiType = APITypeVertexAi } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay-text.go b/relay/relay-text.go index 6f1ce30..f35f140 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) } case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeModerations: - if textRequest.Input == "" { + if textRequest.Input == "" || textRequest.Input == nil { return nil, errors.New("field input is required") } case relayconstant.RelayModeEdits: diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 96fb737..0164782 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -19,6 +19,7 @@ import ( "one-api/relay/channel/siliconflow" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" + "one-api/relay/channel/vertex" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" @@ -65,6 +66,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &cloudflare.Adaptor{} case constant.APITypeSiliconFlow: return &siliconflow.Adaptor{} + case constant.APITypeVertexAi: + return &vertex.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index bd5a8a9..620cab8 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -36,6 +36,7 @@ export const CHANNEL_OPTIONS = [ color: 'indigo', label: 'AWS Claude', }, + { key: 41, text: 'Vertex AI', value: 41, color: 'blue', label: 'Vertex AI' }, { key: 3, text: 'Azure OpenAI', diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index d3c70b6..b62930b 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -37,6 +37,11 @@ const STATUS_CODE_MAPPING_EXAMPLE = { 400: '500', }; +const REGION_EXAMPLE = { + default: 'us-central1', + 'claude-3-5-sonnet-20240620': 'europe-west1', +}; + const fetchButtonTips = '1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出'; @@ -584,6 +589,44 @@ const EditChannel = (props) => { /> )} + {inputs.type === 41 && ( + <> +
+ 部署地区: +
+