feat: Optimize model list.

This commit is contained in:
MartialBE
2024-01-01 14:36:58 +08:00
committed by Buer
parent bf5ba315ee
commit 7ef4a7db59
9 changed files with 327 additions and 138 deletions

View File

@@ -68,7 +68,6 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
const [groupOptions, setGroupOptions] = useState([]);
const [modelOptions, setModelOptions] = useState([]);
const [basicModels, setBasicModels] = useState([]);
const initChannel = (typeValue) => {
if (typeConfig[typeValue]?.inputLabel) {
@@ -96,11 +95,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
) {
return;
}
if (key === 'models') {
setFieldValue(key, initialModel(newInput[key]));
return;
}
setFieldValue(key, newInput[key]);
});
}
};
const basicModels = (channelType) => {
let modelGroup = typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup;
// 循环 modelOptions找到 modelGroup 对应的模型
let modelList = [];
modelOptions.forEach((model) => {
if (model.group === modelGroup) {
modelList.push(model);
}
});
return modelList;
};
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
@@ -113,13 +129,13 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const fetchModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
setModelOptions(res.data.data.map((model) => model.id));
setBasicModels(
res.data.data
.filter((model) => {
return model.id.startsWith('gpt-3') || model.id.startsWith('gpt-4');
})
.map((model) => model.id)
setModelOptions(
res.data.data.map((model) => {
return {
id: model.id,
group: model.owned_by
};
})
);
} catch (error) {
showError(error.message);
@@ -138,12 +154,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
values.other = 'v2.1';
}
let res;
values.models = values.models.join(',');
const modelsStr = values.models.map((model) => model.id).join(',');
values.group = values.groups.join(',');
if (channelId) {
res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId) });
res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId), models: modelsStr });
} else {
res = await API.post(`/api/channel/`, values);
res = await API.post(`/api/channel/`, { ...values, models: modelsStr });
}
const { success, message } = res.data;
if (success) {
@@ -157,11 +173,30 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
onOk(true);
} else {
setStatus({ success: false });
// showError(message);
showError(message);
setErrors({ submit: message });
}
};
function initialModel(channelModel) {
if (!channelModel) {
return [];
}
// 如果 channelModel 是一个字符串
if (typeof channelModel === 'string') {
channelModel = channelModel.split(',');
}
let modelList = channelModel.map((model) => {
const modelOption = modelOptions.find((option) => option.id === model);
if (modelOption) {
return modelOption;
}
return { id: model, group: '自定义:点击或回车输入' };
});
return modelList;
}
const loadChannel = async () => {
let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data;
@@ -169,7 +204,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
if (data.models === '') {
data.models = [];
} else {
data.models = data.models.split(',');
data.models = initialModel(data.models);
}
if (data.group === '') {
data.groups = [];
@@ -348,12 +383,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
freeSolo
id="channel-models-label"
options={modelOptions}
value={Array.isArray(values.models) ? values.models : values.models.split(',')}
value={values.models}
onChange={(e, value) => {
const event = {
target: {
name: 'models',
value: value
value: value.map((item) => (typeof item === 'string' ? { id: item, group: '自定义:点击或回车输入' } : item))
}
};
handleChange(event);
@@ -361,12 +396,25 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
onBlur={handleBlur}
filterSelectedOptions
renderInput={(params) => <TextField {...params} name="models" error={Boolean(errors.models)} label={inputLabel.models} />}
groupBy={(option) => option.group}
getOptionLabel={(option) => {
if (typeof option === 'string') {
return option;
}
if (option.inputValue) {
return option.inputValue;
}
return option.id;
}}
filterOptions={(options, params) => {
const filtered = filter(options, params);
const { inputValue } = params;
const isExisting = options.some((option) => inputValue === option);
const isExisting = options.some((option) => inputValue === option.id);
if (inputValue !== '' && !isExisting) {
filtered.push(inputValue);
filtered.push({
id: inputValue,
group: '自定义:点击或回车输入'
});
}
return filtered;
}}
@@ -387,10 +435,10 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
<ButtonGroup variant="outlined" aria-label="small outlined primary button group">
<Button
onClick={() => {
setFieldValue('models', basicModels);
setFieldValue('models', basicModels(values.type));
}}
>
填入基础模型
填入渠道支持模型
</Button>
<Button
onClick={() => {

View File

@@ -41,13 +41,14 @@ const NameLabel = ({ name, models }) => {
}
placement="top"
>
{name}
<span>{name}</span>
</Tooltip>
);
};
NameLabel.propTypes = {
group: PropTypes.string
name: PropTypes.string,
models: PropTypes.string
};
export default NameLabel;

View File

@@ -35,7 +35,8 @@ const defaultConfig = {
model_mapping:
'请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}',
groups: '请选择该渠道所支持的用户组'
}
},
modelGroup: 'OpenAI'
};
const typeConfig = {
@@ -53,13 +54,15 @@ const typeConfig = {
input: {
models: ['PaLM-2'],
test_model: 'PaLM-2'
}
},
modelGroup: 'Google PaLM'
},
14: {
input: {
models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'],
test_model: 'claude-2'
}
},
modelGroup: 'Anthropic'
},
15: {
input: {
@@ -68,13 +71,15 @@ const typeConfig = {
},
prompt: {
key: '按照如下格式输入APIKey|SecretKey'
}
},
modelGroup: 'Baidu'
},
16: {
input: {
models: ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'],
test_model: 'chatglm_lite'
}
},
modelGroup: 'Zhipu'
},
17: {
inputLabel: {
@@ -96,7 +101,8 @@ const typeConfig = {
},
prompt: {
other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'
}
},
modelGroup: 'Ali'
},
18: {
inputLabel: {
@@ -108,13 +114,15 @@ const typeConfig = {
prompt: {
key: '按照如下格式输入APPID|APISecret|APIKey',
other: '请输入版本号例如v3.1'
}
},
modelGroup: 'Xunfei'
},
19: {
input: {
models: ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'],
test_model: '360GPT_S2_V9'
}
},
modelGroup: '360'
},
22: {
prompt: {
@@ -128,7 +136,8 @@ const typeConfig = {
},
prompt: {
key: '按照如下格式输入AppId|SecretId|SecretKey'
}
},
modelGroup: 'Tencent'
},
25: {
inputLabel: {
@@ -140,13 +149,15 @@ const typeConfig = {
},
prompt: {
other: '请输入版本号例如v1'
}
},
modelGroup: 'Google Gemini'
},
26: {
input: {
models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan2-53B', 'Baichuan-Text-Embedding'],
test_model: 'Baichuan2-Turbo'
}
},
modelGroup: 'Baichuan'
},
24: {
input: {

View File

@@ -0,0 +1,69 @@
import { useState, useEffect } from 'react';
import SubCard from 'ui-component/cards/SubCard';
// import { gridSpacing } from 'store/constant';
import { API } from 'utils/api';
import { showError, showSuccess } from 'utils/common';
import { Typography, Accordion, AccordionSummary, AccordionDetails, Box, Stack } from '@mui/material';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import Label from 'ui-component/Label';
const SupportModels = () => {
const [modelList, setModelList] = useState([]);
const fetchModels = async () => {
try {
let res = await API.get(`/api/user/models`);
if (res === undefined) {
return;
}
// 对 res.data.data 里面的 owned_by 进行分组
let modelGroup = {};
res.data.data.forEach((model) => {
if (modelGroup[model.owned_by] === undefined) {
modelGroup[model.owned_by] = [];
}
modelGroup[model.owned_by].push(model.id);
});
setModelList(modelGroup);
} catch (error) {
showError(error.message);
}
};
useEffect(() => {
fetchModels();
}, []);
return (
<Accordion key="support_models" sx={{ borderRadius: '12px' }}>
<AccordionSummary aria-controls="support_models" expandIcon={<ExpandMoreIcon />}>
<Typography variant="subtitle1">当前可用模型</Typography>
</AccordionSummary>
<AccordionDetails>
<Stack spacing={1}>
{Object.entries(modelList).map(([title, models]) => (
<SubCard key={title} title={title}>
<Box sx={{ display: 'flex', flexWrap: 'wrap', gap: '10px' }}>
{models.map((model) => (
<Label
variant="outlined"
color="primary"
key={model}
onClick={() => {
navigator.clipboard.writeText(model);
showSuccess('复制模型名称成功!');
}}
>
{model}
</Label>
))}
</Box>
</SubCard>
))}
</Stack>
</AccordionDetails>
</Accordion>
);
};
export default SupportModels;

View File

@@ -3,6 +3,7 @@ import { Grid, Typography } from '@mui/material';
import { gridSpacing } from 'store/constant';
import StatisticalLineChartCard from './component/StatisticalLineChartCard';
import StatisticalBarChart from './component/StatisticalBarChart';
import SupportModels from './component/SupportModels';
import { generateChartOptions, getLastSevenDays } from 'utils/chart';
import { API } from 'utils/api';
import { showError, calculateQuota, renderNumber } from 'utils/common';
@@ -50,6 +51,9 @@ const Dashboard = () => {
return (
<Grid container spacing={gridSpacing}>
<Grid item xs={12}>
<SupportModels />
</Grid>
<Grid item xs={12}>
<Grid container spacing={gridSpacing}>
<Grid item lg={4} xs={12}>
@@ -78,6 +82,7 @@ const Dashboard = () => {
</Grid>
</Grid>
</Grid>
<Grid item xs={12}>
<Grid container spacing={gridSpacing}>
<Grid item lg={8} xs={12}>