feat: 支持vertex ai渠道多个部署地区

This commit is contained in:
CalciumIon 2024-08-28 18:43:40 +08:00
parent c41820541d
commit e60f200192
6 changed files with 75 additions and 17 deletions

View File

@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes) return string(bytes)
} }
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} { func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := json.Unmarshal([]byte(str), &m)
@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
return m return m
} }
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int { func String2Int(str string) int {
num, err := strconv.Atoi(str) num, err := strconv.Atoi(str)
if err != nil { if err != nil {

View File

@ -199,6 +199,25 @@ func AddChannel(c *gin.Context) {
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi { if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key} keys = []string{channel.Key}
} }
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))

View File

@ -62,6 +62,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err) return "", fmt.Errorf("failed to decode credentials file: %w", err)
} }
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc a.AccountCredentials = *adc
suffix := "" suffix := ""
if a.RequestMode == RequestModeGemini { if a.RequestMode == RequestModeGemini {
@ -72,9 +73,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
return fmt.Sprintf( return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
info.ApiVersion, region,
adc.ProjectID, adc.ProjectID,
info.ApiVersion, region,
info.UpstreamModelName, info.UpstreamModelName,
suffix, suffix,
), nil ), nil
@ -89,18 +90,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
return fmt.Sprintf( return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
info.ApiVersion, region,
adc.ProjectID, adc.ProjectID,
info.ApiVersion, region,
info.UpstreamModelName, info.UpstreamModelName,
suffix, suffix,
), nil ), nil
} else if a.RequestMode == RequestModeLlama { } else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf( return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
info.ApiVersion, region,
adc.ProjectID, adc.ProjectID,
info.ApiVersion, region,
), nil ), nil
} }
return "", errors.New("unsupported request mode") return "", errors.New("unsupported request mode")

View File

@ -0,0 +1,16 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@ -83,7 +83,7 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
"iss": email, "iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform", "scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token", "aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 30).Unix(), "exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(), "iat": now.Unix(),
} }

View File

@ -37,6 +37,11 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500', 400: '500',
}; };
const REGION_EXAMPLE = {
"default": "us-central1",
"claude-3-5-sonnet-20240620": "europe-west1"
}
const fetchButtonTips = "1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出" const fetchButtonTips = "1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出"
function type2secretPrompt(type) { function type2secretPrompt(type) {
@ -593,17 +598,37 @@ const EditChannel = (props) => {
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>部署地区</Typography.Text> <Typography.Text strong>部署地区</Typography.Text>
</div> </div>
<Input <TextArea
name='other' name='other'
placeholder={ placeholder={
'请输入部署地区例如us-central1' '请输入部署地区例如us-central1\n支持使用模型映射格式\n' +
'{\n' +
' "default": "us-central1",\n' +
' "claude-3-5-sonnet-20240620": "europe-west1"\n' +
'}'
} }
autosize={{ minRows: 2 }}
onChange={(value) => { onChange={(value) => {
handleInputChange('other', value); handleInputChange('other', value);
}} }}
value={inputs.other} value={inputs.other}
autoComplete='new-password' autoComplete='new-password'
/> />
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
}}
onClick={() => {
handleInputChange(
'other',
JSON.stringify(REGION_EXAMPLE, null, 2),
);
}}
>
填入模板
</Typography.Text>
</> </>
)} )}
{inputs.type === 21 && ( {inputs.type === 21 && (