diff --git a/common/constants.go b/common/constants.go index 99c78ac..cb58391 100644 --- a/common/constants.go +++ b/common/constants.go @@ -208,6 +208,8 @@ const ( ChannelTypeLingYiWanWu = 31 ChannelTypeAws = 33 ChannelTypeCohere = 34 + + ChannelTypeDummy // this one is only for count, do not add any channel after this ) var ChannelBaseURLs = []string{ diff --git a/controller/model.go b/controller/model.go index 1e78b4c..4d8e5e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,14 +3,17 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" + "log" "net/http" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" - "one-api/relay/channel/moonshot" "one-api/relay/channel/lingyiwanwu" + "one-api/relay/channel/moonshot" + relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" ) @@ -43,6 +46,7 @@ type OpenAIModels struct { var openAIModels []OpenAIModels var openAIModelsMap map[string]OpenAIModels +var channelId2Models map[int][]string func init() { var permission []OpenAIModelPermission @@ -85,7 +89,7 @@ func init() { Id: modelName, Object: "model", Created: 1626777600, - OwnedBy: "360", + OwnedBy: ai360.ChannelName, Permission: permission, Root: modelName, Parent: nil, @@ -128,6 +132,18 @@ func init() { for _, model := range openAIModels { openAIModelsMap[model.Id] = model } + channelId2Models = make(map[int][]string) + for i := 1; i <= common.ChannelTypeDummy; i++ { + apiType := relayconstant.ChannelType2APIType(i) + if apiType == -1 || apiType == relayconstant.APITypeAIProxyLibrary { + continue + } + log.Println(apiType) + meta := &relaycommon.RelayInfo{ChannelType: i} + adaptor := relay.GetAdaptor(apiType) + adaptor.Init(meta, dto.GeneralOpenAIRequest{}) + channelId2Models[i] = adaptor.GetModelList() + } } func ListModels(c *gin.Context) { @@ -148,15 +164,22 @@ func ListModels(c *gin.Context) { } } c.JSON(200, gin.H{ - "object": "list", - "data": userOpenAiModels, + "success": true, + "data": userOpenAiModels, }) } func ChannelListModels(c *gin.Context) { c.JSON(200, gin.H{ - "object": "list", - "data": openAIModels, + "success": true, + "data": openAIModels, + }) +} + +func DashboardListModels(c *gin.Context) { + c.JSON(200, gin.H{ + "success": true, + "data": channelId2Models, }) } diff --git a/relay/channel/ai360/constants.go b/relay/channel/ai360/constants.go index cfc3cb2..82698fa 100644 --- a/relay/channel/ai360/constants.go +++ b/relay/channel/ai360/constants.go @@ -6,3 +6,5 @@ var ModelList = []string{ "embedding_s1_v1", "semantic_similarity_s1_v1", } + +var ChannelName = "ai360" diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go index 970e977..682626a 100644 --- a/relay/channel/ollama/constants.go +++ b/relay/channel/ollama/constants.go @@ -1,5 +1,7 @@ package ollama -var ModelList []string +var ModelList = []string{ + "llama3-7b", +} var ChannelName = "ollama" diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go index 91f4e51..8c560c7 100644 --- a/relay/channel/openai/constant.go +++ b/relay/channel/openai/constant.go @@ -6,7 +6,7 @@ var ModelList = []string{ "gpt-3.5-turbo-instruct", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-4-turbo-preview", + "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 7f11ae2..1bc8b47 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -25,8 +25,18 @@ const ( ) func ChannelType2APIType(channelType int) int { - apiType := APITypeOpenAI + apiType := -1 switch channelType { + case common.ChannelTypeOpenAI: + apiType = APITypeOpenAI + case common.ChannelTypeAzure: + apiType = APITypeOpenAI + case common.ChannelTypeMoonshot: + apiType = APITypeOpenAI + case common.ChannelTypeLingYiWanWu: + apiType = APITypeOpenAI + case common.ChannelType360: + apiType = APITypeOpenAI case common.ChannelTypeAnthropic: apiType = APITypeAnthropic case common.ChannelTypeBaidu: diff --git a/router/api-router.go b/router/api-router.go index 8547454..8c0ae30 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.Use(middleware.GlobalAPIRateLimit()) { apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index ebfcf4d..452309c 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -31,6 +31,7 @@ import { } from '@douyinfe/semi-ui'; import EditChannel from '../pages/Channel/EditChannel'; import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; +import { loadChannelModels } from './utils.js'; function renderTimestamp(timestamp) { return <>{timestamp2string(timestamp)}; @@ -354,27 +355,29 @@ const ChannelsTable = () => { }; const copySelectedChannel = async (id) => { - const channelToCopy = channels.find(channel => String(channel.id) === String(id)); - console.log(channelToCopy) + const channelToCopy = channels.find( + (channel) => String(channel.id) === String(id), + ); + console.log(channelToCopy); channelToCopy.name += '_复制'; channelToCopy.created_time = null; channelToCopy.balance = 0; channelToCopy.used_quota = 0; if (!channelToCopy) { - showError("渠道未找到,请刷新页面后重试。"); - return; + showError('渠道未找到,请刷新页面后重试。'); + return; } try { - const newChannel = {...channelToCopy, id: undefined}; - const response = await API.post('/api/channel/', newChannel); - if (response.data.success) { - showSuccess("渠道复制成功"); - await refresh(); - } else { - showError(response.data.message); - } + const newChannel = { ...channelToCopy, id: undefined }; + const response = await API.post('/api/channel/', newChannel); + if (response.data.success) { + showSuccess('渠道复制成功'); + await refresh(); + } else { + showError(response.data.message); + } } catch (error) { - showError("渠道复制失败: " + error.message); + showError('渠道复制失败: ' + error.message); } }; @@ -395,6 +398,7 @@ const ChannelsTable = () => { showError(reason); }); fetchGroups().then(); + loadChannelModels().then(); }, []); const manageChannel = async (id, action, record, value) => { diff --git a/web/src/components/utils.js b/web/src/components/utils.js index 59e3a01..1f0ee30 100644 --- a/web/src/components/utils.js +++ b/web/src/components/utils.js @@ -18,3 +18,32 @@ export async function onGitHubOAuthClicked(github_client_id) { `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`, ); } + +let channelModels = undefined; +export async function loadChannelModels() { + const res = await API.get('/api/models'); + const { success, data } = res.data; + if (!success) { + return; + } + channelModels = data; + localStorage.setItem('channel_models', JSON.stringify(data)); +} + +export function getChannelModels(type) { + if (channelModels !== undefined && type in channelModels) { + if (!channelModels[type]) { + return []; + } + return channelModels[type]; + } + let models = localStorage.getItem('channel_models'); + if (!models) { + return []; + } + channelModels = JSON.parse(models); + if (type in channelModels) { + return channelModels[type]; + } + return []; +} diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index c6f0899..94b2fc6 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -86,13 +86,13 @@ export const CHANNEL_OPTIONS = [ label: '智谱 ChatGLM', }, { - key: 16, + key: 26, text: '智谱 GLM-4V', value: 26, color: 'purple', label: '智谱 GLM-4V', }, - { key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, + { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index cc8707d..79e54b4 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -23,6 +23,7 @@ import { Banner, } from '@douyinfe/semi-ui'; import { Divider } from 'semantic-ui-react'; +import { getChannelModels, loadChannelModels } from '../../components/utils.js'; const MODEL_MAPPING_EXAMPLE = { 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', @@ -87,97 +88,9 @@ const EditChannel = (props) => { const [customModel, setCustomModel] = useState(''); const handleInputChange = (name, value) => { setInputs((inputs) => ({ ...inputs, [name]: value })); - if (name === 'type' && inputs.models.length === 0) { + if (name === 'type') { let localModels = []; switch (value) { - case 33: - case 14: - localModels = [ - 'claude-instant-1.2', - 'claude-2', - 'claude-2.0', - 'claude-2.1', - 'claude-3-opus-20240229', - 'claude-3-sonnet-20240229', - 'claude-3-haiku-20240307', - ]; - break; - case 11: - localModels = ['PaLM-2']; - break; - case 15: - localModels = [ - 'ERNIE-Bot', - 'ERNIE-Bot-turbo', - 'ERNIE-Bot-4', - 'Embedding-V1', - ]; - break; - case 17: - localModels = [ - 'qwen-turbo', - 'qwen-plus', - 'qwen-max', - 'qwen-max-longcontext', - 'text-embedding-v1', - ]; - break; - case 16: - localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; - break; - case 18: - localModels = [ - 'SparkDesk', - 'SparkDesk-v1.1', - 'SparkDesk-v2.1', - 'SparkDesk-v3.1', - 'SparkDesk-v3.5', - ]; - break; - case 19: - localModels = [ - '360GPT_S2_V9', - 'embedding-bert-512-v1', - 'embedding_s1_v1', - 'semantic_similarity_s1_v1', - ]; - break; - case 23: - localModels = ['hunyuan']; - break; - case 24: - localModels = [ - 'gemini-1.0-pro-001', - 'gemini-1.0-pro-vision-001', - 'gemini-1.5-pro', - 'gemini-1.5-pro-latest', - 'gemini-pro', - 'gemini-pro-vision', - ]; - break; - case 34: - localModels = [ - 'command-r', - 'command-r-plus', - 'command-light', - 'command-light-nightly', - 'command', - 'command-nightly', - ]; - break; - case 25: - localModels = [ - 'moonshot-v1-8k', - 'moonshot-v1-32k', - 'moonshot-v1-128k', - ]; - break; - case 26: - localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; - break; - case 31: - localModels = ['yi-34b-chat-0205', 'yi-34b-chat-200k', 'yi-vl-plus']; - break; case 2: localModels = [ 'mj_imagine', @@ -207,8 +120,14 @@ const EditChannel = (props) => { 'mj_pan', ]; break; + default: + localModels = getChannelModels(value); + break; } - setInputs((inputs) => ({ ...inputs, models: localModels })); + if (inputs.models.length === 0) { + setInputs((inputs) => ({ ...inputs, models: localModels })); + } + setBasicModels(localModels); } //setAutoBan }; @@ -244,6 +163,7 @@ const EditChannel = (props) => { } else { setAutoBan(true); } + setBasicModels(getChannelModels(data.type)); // console.log(data); } else { showError(message); @@ -312,6 +232,9 @@ const EditChannel = (props) => { loadChannel().then(() => {}); } else { setInputs(originInputs); + let localModels = getChannelModels(inputs.type); + setBasicModels(localModels); + setInputs((inputs) => ({ ...inputs, models: localModels })); } }, [props.editingChannel.id]); @@ -596,7 +519,7 @@ const EditChannel = (props) => { handleInputChange('models', basicModels); }} > - 填入基础模型 + 填入相关模型