From 0f35458cf7dd05941ab7c04c92295534b8701c92 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 7 Aug 2025 09:06:24 +0800 Subject: [PATCH] refactor: api --- pkg/api/http/controller/groups/mcp.py | 192 +++++++++++++------------- 1 file changed, 98 insertions(+), 94 deletions(-) diff --git a/pkg/api/http/controller/groups/mcp.py b/pkg/api/http/controller/groups/mcp.py index 21fcb530..2392c8a4 100644 --- a/pkg/api/http/controller/groups/mcp.py +++ b/pkg/api/http/controller/groups/mcp.py @@ -10,110 +10,114 @@ from .. import group @group.group_class('mcp', '/api/v1/mcp') class MCPRouterGroup(group.RouterGroup): async def initialize(self) -> None: - @self.route('/servers', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) + @self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: """获取MCP服务器列表""" - if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: - return self.success(data={'servers': []}) + if quart.request.method == 'GET': + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.success(data={'servers': []}) - servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', []) + servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', []) - # 获取每个服务器的状态和工具信息 - mcp_loader = None - for loader_name, loader in self.ap.tool_mgr.loaders.items(): - if loader_name == 'mcp': - mcp_loader = loader - break + # 获取每个服务器的状态和工具信息 + mcp_loader = None + for loader_name, loader in self.ap.tool_mgr.loaders.items(): + if loader_name == 'mcp': + mcp_loader = loader + break - servers_with_status = [] - for server in servers: - server_info = { - 'name': server['name'], - 'mode': server['mode'], - 'enable': server['enable'], - 'config': server, - 'status': 'disconnected', - 'tools': [], - 'error': None, + servers_with_status = [] + for server in servers: + server_info = { + 'name': server['name'], + 'mode': server['mode'], + 'enable': server['enable'], + 'config': server, + 'status': 'disconnected', + 'tools': [], + 'error': None, + } + + # 检查服务器连接状态 + if mcp_loader and server['name'] in mcp_loader.sessions: + session = mcp_loader.sessions[server['name']] + server_info['status'] = 'connected' + server_info['tools'] = [ + {'name': func.name, 'description': func.description, 'parameters': func.parameters} + for func in session.functions + ] + elif server['enable']: + server_info['status'] = 'error' + server_info['error'] = 'Failed to connect' + + servers_with_status.append(server_info) + + return self.success(data={'servers': servers_with_status}) + elif quart.request.method == 'POST': + data = await quart.request.json + + # 验证必填字段 + required_fields = ['name', 'mode'] + for field in required_fields: + if field not in data: + return self.http_status(400, -1, f'Missing required field: {field}') + + # 检查provider_cfg是否可用 + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.http_status(500, -1, 'Provider configuration not available') + + # 获取当前配置 + mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) + servers = mcp_config['servers'] + + # 检查服务器名称是否重复 + for server in servers: + if server['name'] == data['name']: + return self.http_status(400, -1, 'Server name already exists') + + # 创建新服务器配置 + new_server = { + 'name': data['name'], + 'mode': data['mode'], + 'enable': data.get('enable', True), } - # 检查服务器连接状态 - if mcp_loader and server['name'] in mcp_loader.sessions: - session = mcp_loader.sessions[server['name']] - server_info['status'] = 'connected' - server_info['tools'] = [ - {'name': func.name, 'description': func.description, 'parameters': func.parameters} - for func in session.functions - ] - elif server['enable']: - server_info['status'] = 'error' - server_info['error'] = 'Failed to connect' + # 根据模式添加配置 + if data['mode'] == 'stdio': + new_server.update( + {'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})} + ) + elif data['mode'] == 'sse': + new_server.update( + { + 'url': data.get('url', ''), + 'headers': data.get('headers', {}), + 'timeout': data.get('timeout', 10), + } + ) - servers_with_status.append(server_info) + # 添加到配置 + servers.append(new_server) + self.ap.provider_cfg.data['mcp'] = mcp_config - return self.success(data={'servers': servers_with_status}) + # 保存配置 + await self.ap.provider_cfg.dump_config() - @self.route('/servers', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) - async def _() -> str: - """创建MCP服务器配置""" - data = await quart.request.json - - # 验证必填字段 - required_fields = ['name', 'mode'] - for field in required_fields: - if field not in data: - return self.http_status(400, -1, f'Missing required field: {field}') - - # 检查provider_cfg是否可用 - if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: - return self.http_status(500, -1, 'Provider configuration not available') - - # 获取当前配置 - mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) - servers = mcp_config['servers'] - - # 检查服务器名称是否重复 - for server in servers: - if server['name'] == data['name']: - return self.http_status(400, -1, 'Server name already exists') - - # 创建新服务器配置 - new_server = { - 'name': data['name'], - 'mode': data['mode'], - 'enable': data.get('enable', True), - } - - # 根据模式添加配置 - if data['mode'] == 'stdio': - new_server.update( - {'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})} - ) - elif data['mode'] == 'sse': - new_server.update( - {'url': data.get('url', ''), 'headers': data.get('headers', {}), 'timeout': data.get('timeout', 10)} - ) - - # 添加到配置 - servers.append(new_server) - self.ap.provider_cfg.data['mcp'] = mcp_config - - # 保存配置 - await self.ap.provider_cfg.dump_config() - - # 如果启用,尝试重新加载MCP loader - if new_server['enable']: - ctx = taskmgr.TaskContext.new() - wrapper = self.ap.task_mgr.create_user_task( - self._reload_mcp_loader(ctx), - kind='mcp-operation', - name=f'mcp-reload-{new_server["name"]}', - label=f'Reloading MCP loader for {new_server["name"]}', - context=ctx, - ) - return self.success(data={'task_id': wrapper.id}) - - return self.success() + # 如果启用,尝试重新加载MCP loader + if new_server['enable']: + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._reload_mcp_loader(ctx), + kind='mcp-operation', + name=f'mcp-reload-{new_server["name"]}', + label=f'Reloading MCP loader for {new_server["name"]}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + else: + return self.success() + else: + return self.http_status(405, -1, 'Method not allowed') @self.route('/servers/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN) async def _(server_name: str) -> str: