refactor: api

This commit is contained in:
Junyan Qin
2025-08-07 09:06:24 +08:00
parent 70ad92ca16
commit 0f35458cf7

View File

@@ -10,110 +10,114 @@ from .. import group
@group.group_class('mcp', '/api/v1/mcp')
class MCPRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/servers', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
@self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""获取MCP服务器列表"""
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.success(data={'servers': []})
if quart.request.method == 'GET':
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.success(data={'servers': []})
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', [])
# 获取每个服务器的状态和工具信息
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
# 获取每个服务器的状态和工具信息
mcp_loader = None
for loader_name, loader in self.ap.tool_mgr.loaders.items():
if loader_name == 'mcp':
mcp_loader = loader
break
servers_with_status = []
for server in servers:
server_info = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
'config': server,
'status': 'disconnected',
'tools': [],
'error': None,
servers_with_status = []
for server in servers:
server_info = {
'name': server['name'],
'mode': server['mode'],
'enable': server['enable'],
'config': server,
'status': 'disconnected',
'tools': [],
'error': None,
}
# 检查服务器连接状态
if mcp_loader and server['name'] in mcp_loader.sessions:
session = mcp_loader.sessions[server['name']]
server_info['status'] = 'connected'
server_info['tools'] = [
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
for func in session.functions
]
elif server['enable']:
server_info['status'] = 'error'
server_info['error'] = 'Failed to connect'
servers_with_status.append(server_info)
return self.success(data={'servers': servers_with_status})
elif quart.request.method == 'POST':
data = await quart.request.json
# 验证必填字段
required_fields = ['name', 'mode']
for field in required_fields:
if field not in data:
return self.http_status(400, -1, f'Missing required field: {field}')
# 检查provider_cfg是否可用
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
# 获取当前配置
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 检查服务器名称是否重复
for server in servers:
if server['name'] == data['name']:
return self.http_status(400, -1, 'Server name already exists')
# 创建新服务器配置
new_server = {
'name': data['name'],
'mode': data['mode'],
'enable': data.get('enable', True),
}
# 检查服务器连接状态
if mcp_loader and server['name'] in mcp_loader.sessions:
session = mcp_loader.sessions[server['name']]
server_info['status'] = 'connected'
server_info['tools'] = [
{'name': func.name, 'description': func.description, 'parameters': func.parameters}
for func in session.functions
]
elif server['enable']:
server_info['status'] = 'error'
server_info['error'] = 'Failed to connect'
# 根据模式添加配置
if data['mode'] == 'stdio':
new_server.update(
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
)
elif data['mode'] == 'sse':
new_server.update(
{
'url': data.get('url', ''),
'headers': data.get('headers', {}),
'timeout': data.get('timeout', 10),
}
)
servers_with_status.append(server_info)
# 添加到配置
servers.append(new_server)
self.ap.provider_cfg.data['mcp'] = mcp_config
return self.success(data={'servers': servers_with_status})
# 保存配置
await self.ap.provider_cfg.dump_config()
@self.route('/servers', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
"""创建MCP服务器配置"""
data = await quart.request.json
# 验证必填字段
required_fields = ['name', 'mode']
for field in required_fields:
if field not in data:
return self.http_status(400, -1, f'Missing required field: {field}')
# 检查provider_cfg是否可用
if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data:
return self.http_status(500, -1, 'Provider configuration not available')
# 获取当前配置
mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []})
servers = mcp_config['servers']
# 检查服务器名称是否重复
for server in servers:
if server['name'] == data['name']:
return self.http_status(400, -1, 'Server name already exists')
# 创建新服务器配置
new_server = {
'name': data['name'],
'mode': data['mode'],
'enable': data.get('enable', True),
}
# 根据模式添加配置
if data['mode'] == 'stdio':
new_server.update(
{'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})}
)
elif data['mode'] == 'sse':
new_server.update(
{'url': data.get('url', ''), 'headers': data.get('headers', {}), 'timeout': data.get('timeout', 10)}
)
# 添加到配置
servers.append(new_server)
self.ap.provider_cfg.data['mcp'] = mcp_config
# 保存配置
await self.ap.provider_cfg.dump_config()
# 如果启用尝试重新加载MCP loader
if new_server['enable']:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{new_server["name"]}',
label=f'Reloading MCP loader for {new_server["name"]}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
return self.success()
# 如果启用尝试重新加载MCP loader
if new_server['enable']:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._reload_mcp_loader(ctx),
kind='mcp-operation',
name=f'mcp-reload-{new_server["name"]}',
label=f'Reloading MCP loader for {new_server["name"]}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
else:
return self.success()
else:
return self.http_status(405, -1, 'Method not allowed')
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str: