feat: 填入相关模型

This commit is contained in:
CaIon 2024-05-12 19:07:33 +08:00
parent d8c006046f
commit 2dbf50dc07
11 changed files with 111 additions and 115 deletions

View File

@ -208,6 +208,8 @@ const (
ChannelTypeLingYiWanWu = 31 ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33 ChannelTypeAws = 33
ChannelTypeCohere = 34 ChannelTypeCohere = 34
ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{

View File

@ -3,14 +3,17 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"log"
"net/http" "net/http"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/channel/ai360" "one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
) )
@ -43,6 +46,7 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() { func init() {
var permission []OpenAIModelPermission var permission []OpenAIModelPermission
@ -85,7 +89,7 @@ func init() {
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: "360", OwnedBy: ai360.ChannelName,
Permission: permission, Permission: permission,
Root: modelName, Root: modelName,
Parent: nil, Parent: nil,
@ -128,6 +132,18 @@ func init() {
for _, model := range openAIModels { for _, model := range openAIModels {
openAIModelsMap[model.Id] = model openAIModelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType := relayconstant.ChannelType2APIType(i)
if apiType == -1 || apiType == relayconstant.APITypeAIProxyLibrary {
continue
}
log.Println(apiType)
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta, dto.GeneralOpenAIRequest{})
channelId2Models[i] = adaptor.GetModelList()
}
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
@ -148,18 +164,25 @@ func ListModels(c *gin.Context) {
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"object": "list", "success": true,
"data": userOpenAiModels, "data": userOpenAiModels,
}) })
} }
func ChannelListModels(c *gin.Context) { func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"object": "list", "success": true,
"data": openAIModels, "data": openAIModels,
}) })
} }
func DashboardListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": channelId2Models,
})
}
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelId := c.Param("model") modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := openAIModelsMap[modelId]; ok {

View File

@ -6,3 +6,5 @@ var ModelList = []string{
"embedding_s1_v1", "embedding_s1_v1",
"semantic_similarity_s1_v1", "semantic_similarity_s1_v1",
} }
var ChannelName = "ai360"

View File

@ -1,5 +1,7 @@
package ollama package ollama
var ModelList []string var ModelList = []string{
"llama3-7b",
}
var ChannelName = "ollama" var ChannelName = "ollama"

View File

@ -6,7 +6,7 @@ var ModelList = []string{
"gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@ -25,8 +25,18 @@ const (
) )
func ChannelType2APIType(channelType int) int { func ChannelType2APIType(channelType int) int {
apiType := APITypeOpenAI apiType := -1
switch channelType { switch channelType {
case common.ChannelTypeOpenAI:
apiType = APITypeOpenAI
case common.ChannelTypeAzure:
apiType = APITypeOpenAI
case common.ChannelTypeMoonshot:
apiType = APITypeOpenAI
case common.ChannelTypeLingYiWanWu:
apiType = APITypeOpenAI
case common.ChannelType360:
apiType = APITypeOpenAI
case common.ChannelTypeAnthropic: case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic apiType = APITypeAnthropic
case common.ChannelTypeBaidu: case common.ChannelTypeBaidu:

View File

@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.GlobalAPIRateLimit())
{ {
apiRouter.GET("/status", controller.GetStatus) apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/about", controller.GetAbout)

View File

@ -31,6 +31,7 @@ import {
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import EditChannel from '../pages/Channel/EditChannel'; import EditChannel from '../pages/Channel/EditChannel';
import { IconTreeTriangleDown } from '@douyinfe/semi-icons'; import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
import { loadChannelModels } from './utils.js';
function renderTimestamp(timestamp) { function renderTimestamp(timestamp) {
return <>{timestamp2string(timestamp)}</>; return <>{timestamp2string(timestamp)}</>;
@ -354,27 +355,29 @@ const ChannelsTable = () => {
}; };
const copySelectedChannel = async (id) => { const copySelectedChannel = async (id) => {
const channelToCopy = channels.find(channel => String(channel.id) === String(id)); const channelToCopy = channels.find(
console.log(channelToCopy) (channel) => String(channel.id) === String(id),
);
console.log(channelToCopy);
channelToCopy.name += '_复制'; channelToCopy.name += '_复制';
channelToCopy.created_time = null; channelToCopy.created_time = null;
channelToCopy.balance = 0; channelToCopy.balance = 0;
channelToCopy.used_quota = 0; channelToCopy.used_quota = 0;
if (!channelToCopy) { if (!channelToCopy) {
showError("渠道未找到,请刷新页面后重试。"); showError('渠道未找到,请刷新页面后重试。');
return; return;
} }
try { try {
const newChannel = {...channelToCopy, id: undefined}; const newChannel = { ...channelToCopy, id: undefined };
const response = await API.post('/api/channel/', newChannel); const response = await API.post('/api/channel/', newChannel);
if (response.data.success) { if (response.data.success) {
showSuccess("渠道复制成功"); showSuccess('渠道复制成功');
await refresh(); await refresh();
} else { } else {
showError(response.data.message); showError(response.data.message);
} }
} catch (error) { } catch (error) {
showError("渠道复制失败: " + error.message); showError('渠道复制失败: ' + error.message);
} }
}; };
@ -395,6 +398,7 @@ const ChannelsTable = () => {
showError(reason); showError(reason);
}); });
fetchGroups().then(); fetchGroups().then();
loadChannelModels().then();
}, []); }, []);
const manageChannel = async (id, action, record, value) => { const manageChannel = async (id, action, record, value) => {

View File

@ -18,3 +18,32 @@ export async function onGitHubOAuthClicked(github_client_id) {
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`, `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
); );
} }
let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');
const { success, data } = res.data;
if (!success) {
return;
}
channelModels = data;
localStorage.setItem('channel_models', JSON.stringify(data));
}
export function getChannelModels(type) {
if (channelModels !== undefined && type in channelModels) {
if (!channelModels[type]) {
return [];
}
return channelModels[type];
}
let models = localStorage.getItem('channel_models');
if (!models) {
return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return [];
}

View File

@ -86,13 +86,13 @@ export const CHANNEL_OPTIONS = [
label: '智谱 ChatGLM', label: '智谱 ChatGLM',
}, },
{ {
key: 16, key: 26,
text: '智谱 GLM-4V', text: '智谱 GLM-4V',
value: 26, value: 26,
color: 'purple', color: 'purple',
label: '智谱 GLM-4V', label: '智谱 GLM-4V',
}, },
{ key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' }, { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },

View File

@ -23,6 +23,7 @@ import {
Banner, Banner,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { Divider } from 'semantic-ui-react'; import { Divider } from 'semantic-ui-react';
import { getChannelModels, loadChannelModels } from '../../components/utils.js';
const MODEL_MAPPING_EXAMPLE = { const MODEL_MAPPING_EXAMPLE = {
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
@ -87,97 +88,9 @@ const EditChannel = (props) => {
const [customModel, setCustomModel] = useState(''); const [customModel, setCustomModel] = useState('');
const handleInputChange = (name, value) => { const handleInputChange = (name, value) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type' && inputs.models.length === 0) { if (name === 'type') {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 33:
case 14:
localModels = [
'claude-instant-1.2',
'claude-2',
'claude-2.0',
'claude-2.1',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
];
break;
case 11:
localModels = ['PaLM-2'];
break;
case 15:
localModels = [
'ERNIE-Bot',
'ERNIE-Bot-turbo',
'ERNIE-Bot-4',
'Embedding-V1',
];
break;
case 17:
localModels = [
'qwen-turbo',
'qwen-plus',
'qwen-max',
'qwen-max-longcontext',
'text-embedding-v1',
];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = [
'SparkDesk',
'SparkDesk-v1.1',
'SparkDesk-v2.1',
'SparkDesk-v3.1',
'SparkDesk-v3.5',
];
break;
case 19:
localModels = [
'360GPT_S2_V9',
'embedding-bert-512-v1',
'embedding_s1_v1',
'semantic_similarity_s1_v1',
];
break;
case 23:
localModels = ['hunyuan'];
break;
case 24:
localModels = [
'gemini-1.0-pro-001',
'gemini-1.0-pro-vision-001',
'gemini-1.5-pro',
'gemini-1.5-pro-latest',
'gemini-pro',
'gemini-pro-vision',
];
break;
case 34:
localModels = [
'command-r',
'command-r-plus',
'command-light',
'command-light-nightly',
'command',
'command-nightly',
];
break;
case 25:
localModels = [
'moonshot-v1-8k',
'moonshot-v1-32k',
'moonshot-v1-128k',
];
break;
case 26:
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
break;
case 31:
localModels = ['yi-34b-chat-0205', 'yi-34b-chat-200k', 'yi-vl-plus'];
break;
case 2: case 2:
localModels = [ localModels = [
'mj_imagine', 'mj_imagine',
@ -207,9 +120,15 @@ const EditChannel = (props) => {
'mj_pan', 'mj_pan',
]; ];
break; break;
default:
localModels = getChannelModels(value);
break;
} }
if (inputs.models.length === 0) {
setInputs((inputs) => ({ ...inputs, models: localModels })); setInputs((inputs) => ({ ...inputs, models: localModels }));
} }
setBasicModels(localModels);
}
//setAutoBan //setAutoBan
}; };
@ -244,6 +163,7 @@ const EditChannel = (props) => {
} else { } else {
setAutoBan(true); setAutoBan(true);
} }
setBasicModels(getChannelModels(data.type));
// console.log(data); // console.log(data);
} else { } else {
showError(message); showError(message);
@ -312,6 +232,9 @@ const EditChannel = (props) => {
loadChannel().then(() => {}); loadChannel().then(() => {});
} else { } else {
setInputs(originInputs); setInputs(originInputs);
let localModels = getChannelModels(inputs.type);
setBasicModels(localModels);
setInputs((inputs) => ({ ...inputs, models: localModels }));
} }
}, [props.editingChannel.id]); }, [props.editingChannel.id]);
@ -596,7 +519,7 @@ const EditChannel = (props) => {
handleInputChange('models', basicModels); handleInputChange('models', basicModels);
}} }}
> >
填入基础模型 填入相关模型
</Button> </Button>
<Button <Button
type='secondary' type='secondary'