From e60f20019256f9406e6377f0fb2e38b236c9aedf Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 28 Aug 2024 18:43:40 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81vertex=20ai?= =?UTF-8?q?=E6=B8=A0=E9=81=93=E5=A4=9A=E4=B8=AA=E9=83=A8=E7=BD=B2=E5=9C=B0?= =?UTF-8?q?=E5=8C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 13 +++++------ controller/channel.go | 19 ++++++++++++++++ relay/channel/vertex/adaptor.go | 13 ++++++----- relay/channel/vertex/relay-vertex.go | 16 ++++++++++++++ relay/channel/vertex/service_account.go | 2 +- web/src/pages/Channel/EditChannel.js | 29 +++++++++++++++++++++++-- 6 files changed, 75 insertions(+), 17 deletions(-) create mode 100644 relay/channel/vertex/relay-vertex.go diff --git a/common/str.go b/common/str.go index d61adb1..d42fd83 100644 --- a/common/str.go +++ b/common/str.go @@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string { return string(bytes) } -func MapToJsonStrFloat(m map[string]float64) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - func StrToMap(str string) map[string]interface{} { m := make(map[string]interface{}) err := json.Unmarshal([]byte(str), &m) @@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} { return m } +func IsJsonStr(str string) bool { + var js map[string]interface{} + return json.Unmarshal([]byte(str), &js) == nil +} + func String2Int(str string) int { num, err := strconv.Atoi(str) if err != nil { diff --git a/controller/channel.go b/controller/channel.go index 65ef721..5d1af01 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -199,6 +199,25 @@ func AddChannel(c *gin.Context) { channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") if channel.Type == common.ChannelTypeVertexAi { + if channel.Other == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区不能为空", + }) + return + } else { + if common.IsJsonStr(channel.Other) { + // must have default + regionMap := common.StrToMap(channel.Other) + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "必须包含default字段", + }) + return + } + } + } keys = []string{channel.Key} } channels := make([]model.Channel, 0, len(keys)) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index e3b4782..4174d78 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -62,6 +62,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials file: %w", err) } + region := GetModelRegion(info.ApiVersion, info.OriginModelName) a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { @@ -72,9 +73,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", - info.ApiVersion, + region, adc.ProjectID, - info.ApiVersion, + region, info.UpstreamModelName, suffix, ), nil @@ -89,18 +90,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - info.ApiVersion, + region, adc.ProjectID, - info.ApiVersion, + region, info.UpstreamModelName, suffix, ), nil } else if a.RequestMode == RequestModeLlama { return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", - info.ApiVersion, + region, adc.ProjectID, - info.ApiVersion, + region, ), nil } return "", errors.New("unsupported request mode") diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go new file mode 100644 index 0000000..d259632 --- /dev/null +++ b/relay/channel/vertex/relay-vertex.go @@ -0,0 +1,16 @@ +package vertex + +import "one-api/common" + +func GetModelRegion(other string, localModelName string) string { + // if other is json string + if common.IsJsonStr(other) { + m := common.StrToMap(other) + if m[localModelName] != nil { + return m[localModelName].(string) + } else { + return m["default"].(string) + } + } + return other +} diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 884d09a..cc64080 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -83,7 +83,7 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) { "iss": email, "scope": "https://www.googleapis.com/auth/cloud-platform", "aud": "https://www.googleapis.com/oauth2/v4/token", - "exp": now.Add(time.Minute * 30).Unix(), + "exp": now.Add(time.Minute * 35).Unix(), "iat": now.Unix(), } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index d9732e9..dec3769 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -37,6 +37,11 @@ const STATUS_CODE_MAPPING_EXAMPLE = { 400: '500', }; +const REGION_EXAMPLE = { + "default": "us-central1", + "claude-3-5-sonnet-20240620": "europe-west1" +} + const fetchButtonTips = "1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出" function type2secretPrompt(type) { @@ -593,17 +598,37 @@ const EditChannel = (props) => {