feat: 支持vertex ai渠道多个部署地区

This commit is contained in:
CalciumIon 2024-08-28 18:43:40 +08:00
parent c41820541d
commit e60f200192
6 changed files with 75 additions and 17 deletions

View File

@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes)
}
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
return m
}
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {

View File

@ -199,6 +199,25 @@ func AddChannel(c *gin.Context) {
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys))

View File

@ -62,6 +62,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
@ -72,9 +73,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
info.ApiVersion,
region,
adc.ProjectID,
info.ApiVersion,
region,
info.UpstreamModelName,
suffix,
), nil
@ -89,18 +90,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
info.ApiVersion,
region,
adc.ProjectID,
info.ApiVersion,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
info.ApiVersion,
region,
adc.ProjectID,
info.ApiVersion,
region,
), nil
}
return "", errors.New("unsupported request mode")

View File

@ -0,0 +1,16 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@ -83,7 +83,7 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 30).Unix(),
"exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(),
}

View File

@ -37,6 +37,11 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500',
};
const REGION_EXAMPLE = {
"default": "us-central1",
"claude-3-5-sonnet-20240620": "europe-west1"
}
const fetchButtonTips = "1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出"
function type2secretPrompt(type) {
@ -593,17 +598,37 @@ const EditChannel = (props) => {
<div style={{ marginTop: 10 }}>
<Typography.Text strong>部署地区</Typography.Text>
</div>
<Input
<TextArea
name='other'
placeholder={
'请输入部署地区例如us-central1'
'请输入部署地区例如us-central1\n支持使用模型映射格式\n' +
'{\n' +
' "default": "us-central1",\n' +
' "claude-3-5-sonnet-20240620": "europe-west1"\n' +
'}'
}
autosize={{ minRows: 2 }}
onChange={(value) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
/>
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
}}
onClick={() => {
handleInputChange(
'other',
JSON.stringify(REGION_EXAMPLE, null, 2),
);
}}
>
填入模板
</Typography.Text>
</>
)}
{inputs.type === 21 && (