import PropTypes from 'prop-types';
import { useState, useEffect } from 'react';
import { CHANNEL_OPTIONS } from 'constants/ChannelConstants';
import { useTheme } from '@mui/material/styles';
import { API } from 'utils/api';
import { showError, showSuccess, getChannelModels } from 'utils/common';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
TextField,
Button,
Divider,
Select,
MenuItem,
FormControl,
InputLabel,
OutlinedInput,
ButtonGroup,
Container,
Autocomplete,
FormHelperText,
Switch,
Checkbox
} from '@mui/material';
import { Formik } from 'formik';
import * as Yup from 'yup';
import { defaultConfig, typeConfig } from '../type/Config'; //typeConfig
import { createFilterOptions } from '@mui/material/Autocomplete';
import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank';
import CheckBoxIcon from '@mui/icons-material/CheckBox';
// NOTE(review): the initializers of `icon` and `checkedIcon` are empty in this copy of the file
// (`const icon = ;` is a syntax error) — the JSX appears to have been stripped. Judging by the
// CheckBoxOutlineBlankIcon / CheckBoxIcon imports above they were presumably `<CheckBoxOutlineBlankIcon ... />`
// and `<CheckBoxIcon ... />`; restore them before building.
const icon = ;
const checkedIcon = ;
// Option filter built with MUI's createFilterOptions(); presumably passed to an Autocomplete in the
// (not visible) JSX — confirm.
const filter = createFilterOptions();
/**
 * Yup validation schema for the channel form.
 *
 * - `key` is required only when creating (`is_edit` false) and the type is not 33.
 * - `base_url` is required for types 3 and 8.
 * - `model_mapping` may be empty; otherwise it must parse as a JSON object.
 */
const validationSchema = Yup.object().shape({
  is_edit: Yup.boolean(),
  name: Yup.string().required('名称 不能为空'),
  type: Yup.number().required('渠道 不能为空'),
  key: Yup.string().when(['is_edit', 'type'], {
    is: (is_edit, type) => !is_edit && type !== 33,
    then: Yup.string().required('密钥 不能为空')
  }),
  other: Yup.string(),
  models: Yup.array().min(1, '模型 不能为空'),
  groups: Yup.array().min(1, '用户组 不能为空'),
  base_url: Yup.string().when('type', {
    is: (value) => [3, 8].includes(value),
    then: Yup.string().required('渠道API地址 不能为空'), // base_url is required for these types
    otherwise: Yup.string() // for every other type any string (including empty) is accepted
  }),
  model_mapping: Yup.string().test('is-json', '必须是有效的JSON字符串', function (value) {
    // An empty or missing mapping is allowed.
    if (value === '' || value === null || value === undefined) {
      return true;
    }
    try {
      const parsedValue = JSON.parse(value);
      // `typeof null === 'object'` and arrays are objects too; neither is a key/value mapping.
      return parsedValue !== null && typeof parsedValue === 'object' && !Array.isArray(parsedValue);
    } catch (e) {
      return false;
    }
  })
});
const EditModal = ({ open, channelId, onCancel, onOk }) => {
// MUI theme; its consumer is in the JSX, which is not visible in this copy.
const theme = useTheme();
// Initial form values: replaced by loadChannel (existing channel) or by the channelId effect (new channel).
// Presumably fed to <Formik initialValues> in the JSX — confirm.
const [initialInput, setInitialInput] = useState(defaultConfig.input);
// Field labels: defaultConfig.inputLabel overlaid with the selected type's overrides (see initChannel).
const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel);
// Field helper prompts, built the same way as the labels (see initChannel).
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
// Group names loaded from /api/group/ by fetchGroups.
const [groupOptions, setGroupOptions] = useState([]);
// `{ id, group }` model options loaded by fetchModels; read by initialModel.
const [modelOptions, setModelOptions] = useState([]);
// Batch-add toggle; reset to false whenever channelId changes.
const [batchAdd, setBatchAdd] = useState(false);
// Built-in model list for the currently selected type, set by handleTypeChange.
const [basicModels, setBasicModels] = useState([]);
/**
 * Applies the per-type form configuration for `typeValue`.
 *
 * Labels and prompts start from `defaultConfig` and are overlaid with the overrides found in
 * the type's `typeConfig` entry. When a type has no overrides, the default object itself is used.
 *
 * @param {number} typeValue - Channel type, used as the key into `typeConfig`.
 * @returns {object|undefined} The type's `input` preset, or undefined when it has none.
 */
const initChannel = (typeValue) => {
  const typeEntry = typeConfig[typeValue];

  const labelOverrides = typeEntry?.inputLabel;
  setInputLabel(labelOverrides ? { ...defaultConfig.inputLabel, ...labelOverrides } : defaultConfig.inputLabel);

  const promptOverrides = typeEntry?.prompt;
  setInputPrompt(promptOverrides ? { ...defaultConfig.prompt, ...promptOverrides } : defaultConfig.prompt);

  return typeEntry?.input;
};
/**
 * Reacts to the channel type changing.
 *
 * It applies the new type's labels and prompts and records the type's built-in model list.
 * It pre-fills `models` only when the form has none selected yet, and it clears the
 * type-specific `config` object.
 *
 * @param {Function} setFieldValue - Form field setter, called as setFieldValue(field, value).
 * @param {number} typeValue - The newly selected channel type.
 * @param {object} values - Current form values.
 */
const handleTypeChange = (setFieldValue, typeValue, values) => {
  initChannel(typeValue);
  const localModels = getChannelModels(typeValue);
  setBasicModels(localModels);
  // Only pre-fill an empty selection, so switching type never discards models the user picked.
  if (localModels.length > 0 && Array.isArray(values.models) && values.models.length === 0) {
    setFieldValue('models', initialModel(localModels));
  }
  // Type-specific settings do not carry over between channel types.
  setFieldValue('config', {});
};
/**
 * Loads the list of user groups from the API and stores it as `groupOptions`.
 * Any request failure is reported to the user through showError.
 */
const fetchGroups = async () => {
  try {
    const response = await API.get(`/api/group/`);
    const groups = response.data.data;
    setGroupOptions(groups);
  } catch (err) {
    showError(err.message);
  }
};
/**
 * Loads every model known to the backend and stores them as `{ id, group }` options.
 *
 * `group` is the model's `owned_by`, or '未知' when that is missing. Options are ordered by
 * owner first, then by model id. Failures are reported through showError.
 */
const fetchModels = async () => {
  try {
    const res = await API.get(`/api/channel/models`);
    const { data: models } = res.data;
    // Give owner-less models a placeholder owner (in place) so sorting and grouping always see a string.
    models.forEach((model) => {
      if (!model.owned_by) {
        model.owned_by = '未知';
      }
    });
    // Owner first; model id breaks ties.
    models.sort((a, b) => a.owned_by.localeCompare(b.owned_by) || a.id.localeCompare(b.id));
    const options = models.map(({ id, owned_by: group }) => ({ id, group }));
    setModelOptions(options);
  } catch (err) {
    showError(err.message);
  }
};
/**
 * Form submit handler: normalizes the values, then creates (POST) or updates (PUT) the channel.
 *
 * It works on a shallow copy, so the form library's own `values` object is never mutated.
 * The request body sends `models` as a comma-separated id string, `groups` joined into `group`,
 * and `config` as a JSON string. Both kinds of failure are reported through showError, status
 * and errors, and both clear the submitting flag:
 * - a server reply with `success: false`;
 * - a rejected request, such as a network error.
 *
 * @param {object} values - Current form values.
 * @param {{ setErrors: Function, setStatus: Function, setSubmitting: Function }} helpers
 */
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
  setSubmitting(true);
  const payload = { ...values };

  if (payload.base_url) {
    // Store the base URL without any trailing slashes.
    payload.base_url = payload.base_url.replace(/\/+$/, '');
  }
  // Types 3 and 18 get a default `other` value when it is left blank.
  if (payload.type === 3 && payload.other === '') {
    payload.other = '2023-09-01-preview';
  }
  if (payload.type === 18 && payload.other === '') {
    payload.other = 'v2.1';
  }
  // With no explicit key, build one as ak|sk|region when all three config fields are present.
  if (payload.key === '' && payload.config?.ak && payload.config?.sk && payload.config?.region) {
    payload.key = `${payload.config.ak}|${payload.config.sk}|${payload.config.region}`;
  }
  payload.group = payload.groups.join(',');

  const body = {
    ...payload,
    models: payload.models.map((model) => model.id).join(','),
    config: JSON.stringify(payload.config)
  };

  let res;
  try {
    res = channelId
      ? await API.put(`/api/channel/`, { ...body, id: Number.parseInt(channelId, 10) })
      : await API.post(`/api/channel/`, body);
  } catch (error) {
    // A rejected request is surfaced the same way as a server-reported failure.
    setSubmitting(false);
    setStatus({ success: false });
    showError(error.message);
    setErrors({ submit: error.message });
    return;
  }

  const { success, message } = res.data;
  if (success) {
    if (channelId) {
      showSuccess('渠道更新成功!');
    } else {
      showSuccess('渠道创建成功!');
    }
    setSubmitting(false);
    setStatus({ success: true });
    onOk(true);
  } else {
    setSubmitting(false);
    setStatus({ success: false });
    showError(message);
    setErrors({ submit: message });
  }
};
/**
 * Converts a channel's model list into `{ id, group }` option objects.
 *
 * Accepts a comma-separated string or an array of model ids. Ids that match a loaded model
 * option reuse that option object; any other id becomes a custom entry.
 *
 * @param {string|string[]|null|undefined} channelModel
 * @returns {{ id: string, group: string }[]} An empty array for a falsy input.
 */
function initialModel(channelModel) {
  if (!channelModel) {
    return [];
  }
  const ids = typeof channelModel === 'string' ? channelModel.split(',') : channelModel;
  return ids.map(
    (id) => modelOptions.find((option) => option.id === id) || { id, group: '自定义:点击或回车输入' }
  );
}
/**
 * Loads channel `channelId` and stores it as the form's initial values.
 *
 * The API record is normalized for the form:
 * - `models` becomes option objects, via initialModel;
 * - `group` is split into a `groups` array;
 * - `model_mapping` is pretty-printed;
 * - `config` is parsed into an object;
 * - `base_url` defaults to ''.
 *
 * The type's labels and prompts are then applied via initChannel.
 *
 * These failures are reported through showError and leave the form untouched:
 * - a rejected request;
 * - `success: false`;
 * - malformed `config` JSON.
 */
const loadChannel = async () => {
  try {
    const res = await API.get(`/api/channel/${channelId}`);
    const { success, message, data } = res.data;
    if (!success) {
      showError(message);
      return;
    }

    // Empty or missing lists become [] rather than being split or mapped.
    data.models = data.models ? initialModel(data.models) : [];
    data.groups = data.group ? data.group.split(',') : [];

    if (data.model_mapping) {
      try {
        data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
      } catch {
        // Keep malformed JSON as-is so the user sees it and the form validation flags it.
      }
    } else {
      // The form treats '' as "no mapping"; null/undefined would otherwise become the text 'null'.
      data.model_mapping = '';
    }

    // An empty or missing config, or a stored JSON `null`, becomes {} so config fields can be read safely.
    data.config = data.config ? (JSON.parse(data.config) ?? {}) : {};

    data.base_url = data.base_url ?? '';
    data.is_edit = true;
    initChannel(data.type);
    setInitialInput(data);
  } catch (error) {
    showError(error.message);
  }
};
// On mount: load the group and model option lists once.
useEffect(() => {
  fetchGroups().then();
  fetchModels().then();
}, []);
// Whenever channelId changes: reset batch-add mode, then either load the existing channel,
// or start a fresh form with type 1's labels/prompts and the default input values.
// NOTE(review): loadChannel -> initialModel reads `modelOptions` from the render in which this
// effect ran. On the first open that is still [] because fetchModels has not resolved yet, so the
// loaded models may all be labelled as custom entries. Confirm, and consider re-deriving once
// modelOptions arrives.
useEffect(() => {
  setBatchAdd(false);
  if (channelId) {
    loadChannel().then();
  } else {
    initChannel(1);
    setInitialInput({ ...defaultConfig.input, is_edit: false });
  }
  // eslint-disable-next-line react-hooks/exhaustive-deps
}, [channelId]);
// NOTE(review): the JSX this component renders is missing from this copy of the file
// (`return ( );` is not valid syntax). Restore the Dialog/Formik markup before building.
return (
);
};
/** Runtime prop validation for EditModal. */
EditModal.propTypes = {
  // Presumably controls whether the dialog is shown; its use is in the JSX, which is not visible here.
  open: PropTypes.bool,
  // When truthy, the modal loads that channel and saves it with PUT; otherwise it creates one with POST.
  channelId: PropTypes.number,
  // Presumably called when the dialog is dismissed; its usage is not visible here.
  onCancel: PropTypes.func,
  // Called with `true` after a successful save.
  onOk: PropTypes.func
};

export default EditModal;