import PropTypes from 'prop-types';
import { useState, useEffect } from 'react';
import { CHANNEL_OPTIONS } from 'constants/ChannelConstants';
import { useTheme } from '@mui/material/styles';
import { API } from 'utils/api';
import { showError, showSuccess, trims } from 'utils/common';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
TextField,
Button,
Divider,
Select,
MenuItem,
FormControl,
InputLabel,
OutlinedInput,
ButtonGroup,
Container,
Autocomplete,
FormHelperText,
Checkbox,
Switch,
FormControlLabel,
Typography,
Tooltip
} from '@mui/material';
import LoadingButton from '@mui/lab/LoadingButton';
import { Formik } from 'formik';
import * as Yup from 'yup';
import { defaultConfig, typeConfig } from '../type/Config'; //typeConfig
import { createFilterOptions } from '@mui/material/Autocomplete';
import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank';
import CheckBoxIcon from '@mui/icons-material/CheckBox';
// Plugin definitions, looked up by channel type value in handleTypeChange.
const pluginList = require('../type/Plugin.json');
// NOTE(review): the initializers of `icon` and `checkedIcon` are missing in this view — presumably the
// CheckBoxOutlineBlankIcon / CheckBoxIcon elements imported above; confirm against the original file.
const icon = ;
const checkedIcon = ;
// Option filter built by MUI's createFilterOptions with no custom configuration.
const filter = createFilterOptions();
/**
 * Yup test callback for `model_mapping`.
 *
 * Passes when the value is blank ('', null or undefined) or when it is JSON text whose parsed
 * result has `typeof === 'object'`. Fails for unparsable text and for JSON primitives.
 * NOTE(review): because the check is `typeof`, the JSON texts `null` and `[...]` also pass.
 *
 * @param {*} value - Current field value.
 * @returns {boolean} Whether the value is acceptable.
 */
function isBlankOrJsonObjectText(value) {
  if (value === '' || value === null || value === undefined) {
    return true;
  }
  let parsed;
  try {
    parsed = JSON.parse(value);
  } catch (e) {
    return false;
  }
  return typeof parsed === 'object';
}

/** Validation rules for the channel edit form. */
const validationSchema = Yup.object().shape({
  is_edit: Yup.boolean(),
  name: Yup.string().required('名称 不能为空'),
  type: Yup.number().required('渠道 不能为空'),
  // The key is only mandatory while creating a channel (is_edit === false).
  key: Yup.string().when('is_edit', {
    is: false,
    then: Yup.string().required('密钥 不能为空')
  }),
  other: Yup.string(),
  proxy: Yup.string(),
  test_model: Yup.string(),
  models: Yup.array().min(1, '模型 不能为空'),
  groups: Yup.array().min(1, '用户组 不能为空'),
  base_url: Yup.string().when('type', {
    is: (channelType) => channelType === 3 || channelType === 8,
    // base_url is mandatory for channel types 3 and 8
    then: Yup.string().required('渠道API地址 不能为空'),
    // for every other type base_url may be any string
    otherwise: Yup.string()
  }),
  model_mapping: Yup.string().test('is-json', '必须是有效的JSON字符串', isBlankOrJsonObjectText)
});
const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
const theme = useTheme();
// Form values to start from: the defaults, or the channel data stored by loadChannel.
// NOTE(review): presumably handed to Formik as its initial values — the JSX is not visible here; confirm.
const [initialInput, setInitialInput] = useState(defaultConfig.input);
// Field labels and prompts for the current channel type; initChannel merges typeConfig overrides over the defaults.
const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel);
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
// Models returned by /api/channel/models, stored as { id, group } entries by fetchModels.
const [modelOptions, setModelOptions] = useState([]);
// Reset to false whenever channelId changes. NOTE(review): where it is switched on is not visible here.
const [batchAdd, setBatchAdd] = useState(false);
// True while getProviderModels is waiting for its request.
const [providerModelsLoad, setProviderModelsLoad] = useState(false);
/**
 * Applies the label and prompt texts for a channel type.
 *
 * Type-specific `inputLabel` / `prompt` entries from `typeConfig` are merged over the defaults;
 * when a type defines none, the defaults are used unchanged.
 *
 * @param {number|string} typeValue - Channel type used as the key into `typeConfig`.
 * @returns {object|undefined} The type's `input` entry from `typeConfig`, if it has one.
 */
const initChannel = (typeValue) => {
  const channelConfig = typeConfig[typeValue];
  const labelOverrides = channelConfig?.inputLabel;
  const promptOverrides = channelConfig?.prompt;

  setInputLabel(labelOverrides ? { ...defaultConfig.inputLabel, ...labelOverrides } : defaultConfig.inputLabel);
  setInputPrompt(promptOverrides ? { ...defaultConfig.prompt, ...promptOverrides } : defaultConfig.prompt);

  return channelConfig?.input;
};
/**
 * Reacts to a change of channel type.
 *
 * 1. If `pluginList` has an entry for the new type, rebuilds the `plugin` form value from it,
 *    keeping previously entered (truthy) parameter values and defaulting the rest to
 *    `false` for `bool` parameters and `''` otherwise.
 * 2. Applies the type's labels/prompts via `initChannel` and copies the type's default input
 *    into every form field that is still empty.
 *
 * @param {Function} setFieldValue - Setter called as `setFieldValue(fieldName, value)`.
 * @param {number|string} typeValue - The newly selected channel type.
 * @param {object} values - Current form values.
 */
const handleTypeChange = (setFieldValue, typeValue, values) => {
  // Plugin handling
  const pluginConfig = pluginList[typeValue];
  if (pluginConfig) {
    const pluginValues = {};
    for (const pluginName in pluginConfig) {
      const plugin = pluginConfig[pluginName];
      const previousParams = values.plugin?.[pluginName] || {};
      const nextParams = {};
      for (const paramName in plugin.params) {
        nextParams[paramName] = previousParams[paramName] || (plugin.params[paramName].type === 'bool' ? false : '');
      }
      pluginValues[pluginName] = nextParams;
    }
    setFieldValue('plugin', pluginValues);
  }

  const typeDefaults = initChannel(typeValue);
  if (!typeDefaults) {
    return;
  }

  for (const field of Object.keys(typeDefaults)) {
    const current = values[field];
    // A field counts as filled when it is a non-empty array, or a non-array that is not null/undefined/''.
    const isFilled = Array.isArray(current) ? current.length > 0 : current !== null && current !== undefined && current !== '';
    if (isFilled) {
      continue;
    }
    // Model defaults are converted to option objects before being stored.
    const nextValue = field === 'models' ? initialModel(typeDefaults[field]) : typeDefaults[field];
    setFieldValue(field, nextValue);
  }
};
/**
 * Lists the known models that belong to a channel type's model group.
 *
 * The group name comes from `typeConfig[channelType].modelGroup`, falling back to
 * `defaultConfig.modelGroup` when the type defines none.
 *
 * @param {number|string} channelType - Channel type used as the key into `typeConfig`.
 * @returns {Array<object>} Entries of `modelOptions` whose `group` equals that group name.
 */
const basicModels = (channelType) => {
  const targetGroup = typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup;
  return modelOptions.filter((option) => option.group === targetGroup);
};
/**
 * Asks the backend for the provider's model list and stores it in the `models` form field.
 *
 * The current form values are posted with `models` blanked out. Duplicate model names in the
 * response are dropped and each name becomes an option object in the custom-input group.
 * Failures are reported through `showError`; the loading flag is raised for the duration.
 *
 * @param {object} values - Current form values, sent as the request body.
 * @param {Function} setFieldValue - Setter called as `setFieldValue('models', options)`.
 * @returns {Promise<void>}
 */
const getProviderModels = async (values, setFieldValue) => {
  setProviderModelsLoad(true);
  try {
    const response = await API.post('/api/channel/provider_models_list', { ...values, models: '' });
    const { success, message, data } = response.data;
    if (!(success && data)) {
      showError(message || '获取模型列表失败');
    } else {
      const providerOptions = [...new Set(data)].map((modelId) => ({
        id: modelId,
        group: '自定义：点击或回车输入'
      }));
      setFieldValue('models', providerOptions);
    }
  } catch (error) {
    showError(error.message);
  }
  setProviderModelsLoad(false);
};
/**
 * Loads the model catalogue from `/api/channel/models` into `modelOptions`.
 *
 * The response list is sorted in place by `owned_by`, then by `id`, and each entry is stored
 * as `{ id, group }` where `group` is the model's `owned_by`. Any failure is reported through
 * `showError`.
 *
 * @returns {Promise<void>}
 */
const fetchModels = async () => {
  try {
    const response = await API.get('/api/channel/models');
    const { data } = response.data;
    // Sort first: owner, then model id as the tie-breaker.
    data.sort((left, right) => left.owned_by.localeCompare(right.owned_by) || left.id.localeCompare(right.id));
    const options = data.map((model) => ({ id: model.id, group: model.owned_by }));
    setModelOptions(options);
  } catch (error) {
    showError(error.message);
  }
};
/**
 * Form submit handler: normalises the form values and creates or updates the channel.
 *
 * Normalisation is applied to a copy, so the object passed in by the form is not mutated:
 * - values are passed through `trims`;
 * - one trailing `/` is removed from `base_url`;
 * - an empty `other` gets a default for channel type 3 (`2023-09-01-preview`) and 18 (`v2.1`);
 * - `groups` is joined into the comma-separated `group` field;
 * - `models` is sent as a comma-separated list of model ids.
 *
 * When `channelId` is set the channel is updated with PUT (with `id` added), otherwise it is
 * created with POST. On success `onOk(true)` is called. On failure the message is shown and
 * stored under `errors.submit`. The submitting flag is cleared on every outcome.
 *
 * @param {object} values - Current form values; must contain `models` (array of `{ id }`) and `groups` (array).
 * @param {object} helpers - Form helpers.
 * @param {Function} helpers.setErrors - Receives `{ submit: message }` on failure.
 * @param {Function} helpers.setStatus - Receives `{ success: boolean }`.
 * @param {Function} helpers.setSubmitting - Toggles the submitting flag.
 * @returns {Promise<void>}
 */
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
  setSubmitting(true);

  // Work on a shallow copy so the caller-owned form values stay untouched.
  const payload = { ...trims(values) };
  if (payload.base_url && payload.base_url.endsWith('/')) {
    payload.base_url = payload.base_url.slice(0, -1);
  }
  if (payload.type === 3 && payload.other === '') {
    payload.other = '2023-09-01-preview';
  }
  if (payload.type === 18 && payload.other === '') {
    payload.other = 'v2.1';
  }
  const modelsStr = payload.models.map((model) => model.id).join(',');
  payload.group = payload.groups.join(',');

  // Shared failure path: record the error and release the submitting flag.
  const reportFailure = (message) => {
    setStatus({ success: false });
    showError(message);
    setErrors({ submit: message });
    setSubmitting(false);
  };

  try {
    let res;
    if (channelId) {
      res = await API.put(`/api/channel/`, { ...payload, id: Number.parseInt(channelId, 10), models: modelsStr });
    } else {
      res = await API.post(`/api/channel/`, { ...payload, models: modelsStr });
    }
    const { success, message } = res.data;
    if (!success) {
      reportFailure(message);
      return;
    }
    showSuccess(channelId ? '渠道更新成功！' : '渠道创建成功！');
    setSubmitting(false);
    setStatus({ success: true });
    onOk(true);
  } catch (error) {
    reportFailure(error.message);
  }
};
/**
 * Converts a channel's model list into option objects for the form.
 *
 * @param {string|Array<string>} channelModel - Model ids, either as an array or as a
 *   comma-separated string. A falsy value yields an empty list.
 * @returns {Array<object>} For each id, the matching entry of `modelOptions`, or a new
 *   `{ id, group }` entry in the custom-input group when the id is not a known model.
 */
function initialModel(channelModel) {
  if (!channelModel) {
    return [];
  }
  // A string is treated as a comma-separated id list.
  const modelIds = typeof channelModel === 'string' ? channelModel.split(',') : channelModel;
  return modelIds.map((modelId) => {
    const knownOption = modelOptions.find((option) => option.id === modelId);
    return knownOption || { id: modelId, group: '自定义：点击或回车输入' };
  });
}
/**
 * Loads the channel identified by `channelId` and prepares it as the form's initial values.
 *
 * The response data is adapted for the form before being stored with `setInitialInput`:
 * - `models` becomes a list of option objects (empty list for an empty string);
 * - `groups` is derived by splitting the comma-separated `group` field;
 * - a non-empty `model_mapping` is re-serialised as pretty-printed JSON;
 * - a missing `base_url` becomes `''`, a `null` `plugin` becomes `{}`, and `is_edit` is set.
 *
 * An unsuccessful response, a failed request, or an error while adapting the data (for
 * example a `model_mapping` that is not valid JSON) is reported through `showError`, matching
 * the other loaders in this component, instead of being dropped silently.
 *
 * @returns {Promise<void>}
 */
const loadChannel = async () => {
  try {
    const res = await API.get(`/api/channel/${channelId}`);
    const { success, message, data } = res.data;
    if (!success) {
      showError(message);
      return;
    }

    data.models = data.models === '' ? [] : initialModel(data.models);
    data.groups = data.group === '' ? [] : data.group.split(',');
    if (data.model_mapping !== '') {
      // Normalise the stored mapping to an indented form for editing.
      data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
    }
    data.base_url = data.base_url ?? '';
    data.is_edit = true;
    if (data.plugin === null) {
      data.plugin = {};
    }

    initChannel(data.type);
    setInitialInput(data);
  } catch (error) {
    showError(error.message);
  }
};
// Fetch the model catalogue once, after the first render.
useEffect(() => {
  void fetchModels();
}, []);

// Whenever the target channel changes: leave batch mode, then either load that channel
// or reset the form to the defaults of channel type 1 for a new channel.
useEffect(() => {
  setBatchAdd(false);
  if (!channelId) {
    initChannel(1);
    setInitialInput({ ...defaultConfig.input, is_edit: false });
    return;
  }
  void loadChannel();
  // eslint-disable-next-line react-hooks/exhaustive-deps
}, [channelId]);
return (
);
};
export default EditModal;
// Runtime prop validation for EditModal.
EditModal.propTypes = {
  // NOTE(review): not read by the logic visible here — presumably controls dialog visibility; confirm in the JSX.
  open: PropTypes.bool,
  // When truthy, that channel is loaded for editing and saved with PUT; otherwise a new channel is created with POST.
  channelId: PropTypes.number,
  // NOTE(review): not called by the logic visible here — presumably wired to the dialog's cancel action; confirm.
  onCancel: PropTypes.func,
  // Called with `true` after a channel has been created or updated successfully.
  onOk: PropTypes.func,
  // NOTE(review): not read by the logic visible here — presumably the selectable user groups; confirm.
  groupOptions: PropTypes.array
};