diff --git a/pkg/api/http/controller/groups/resources/mcp.py b/pkg/api/http/controller/groups/resources/mcp.py index b3bb18b5..bd22062a 100644 --- a/pkg/api/http/controller/groups/resources/mcp.py +++ b/pkg/api/http/controller/groups/resources/mcp.py @@ -39,9 +39,11 @@ class MCPRouterGroup(group.RouterGroup): data = await quart.request.json data = data['source'] - uuid = await self.ap.mcp_service.create_mcp_server(data) - - return self.success(data={'uuid': uuid}) + try: + uuid = await self.ap.mcp_service.create_mcp_server(data) + return self.success(data={'uuid': uuid}) + except Exception as e: + return self.http_status(500, -1, f'Failed to create MCP server: {str(e)}') @self.route('/servers/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN) async def _(server_name: str) -> str: @@ -56,12 +58,18 @@ class MCPRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': data = await quart.request.json - await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data) - return self.success() + try: + await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data) + return self.success() + except Exception as e: + return self.http_status(500, -1, f'Failed to update MCP server: {str(e)}') elif quart.request.method == 'DELETE': - await self.ap.mcp_service.delete_mcp_server(server_data['uuid']) - return self.success() + try: + await self.ap.mcp_service.delete_mcp_server(server_data['uuid']) + return self.success() + except Exception as e: + return self.http_status(500, -1, f'Failed to delete MCP server: {str(e)}') @self.route('/servers//test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _(server_name: str) -> str: @@ -71,49 +79,6 @@ class MCPRouterGroup(group.RouterGroup): if server_data is None: return self.http_status(404, -1, 'Server not found') - -# TODO 这里移到service去 -# # 创建测试任务 -# ctx = taskmgr.TaskContext.new() -# wrapper = self.ap.task_mgr.create_user_task( -# self._test_mcp_server(server, ctx), -# kind='mcp-operation', -# name=f'mcp-test-{server_name}', -# label=f'Testing MCP server {server_name}', -# context=ctx, -# ) -# return self.success(data={'task_id': wrapper.id}) - -# async def _test_mcp_server(self, server: persistence_mcp.MCPServer, ctx: taskmgr.TaskContext): -# """测试MCP服务器连接""" -# try: - -# ctx.current_action = f'Testing connection to {server.name}' -# # 创建临时会话进行测试 -# session = RuntimeMCPSession(server.name, { -# 'name': server.name, -# 'mode': server.mode, -# 'enable': server.enable, -# 'url': server.extra_args.get('url',''), -# 'headers': server.extra_args.get('headers',{}), -# 'timeout': server.extra_args.get('timeout',60), -# },enable=True, ap=self.ap) -# await session.start() - -# # 获取工具列表作为测试 -# tools_count = len(session.functions) - -# tool_name_list = [] -# for function in session.functions: -# tool_name_list.append(function.name) -# ctx.current_action = f'Successfully connected. Found {tools_count} tools.' - -# # 关闭测试会话 -# await session.shutdown() - -# return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list} - -# except Exception as e: -# print(traceback.format_exc()) -# ctx.current_action = f'Connection test failed: {str(e)}' -# raise e + + task_id = await self.ap.mcp_service.test_mcp_server(server_data['uuid']) + return self.success(data={'task_id': task_id}) diff --git a/pkg/api/http/service/mcp.py b/pkg/api/http/service/mcp.py index 3edf4123..4250d756 100644 --- a/pkg/api/http/service/mcp.py +++ b/pkg/api/http/service/mcp.py @@ -2,9 +2,107 @@ from __future__ import annotations import sqlalchemy import uuid +import traceback from ....core import app from ....entity.persistence import mcp as persistence_mcp +from ....core import taskmgr +from ....provider.tools.loaders.mcp import RuntimeMCPSession + + +class RuntimeMCPServer: + """Runtime MCP Server representation""" + + ap: app.Application + + mcp_server_entity: persistence_mcp.MCPServer + + session: RuntimeMCPSession | None = None + + def __init__(self, ap: app.Application, mcp_server_entity: persistence_mcp.MCPServer): + self.ap = ap + self.mcp_server_entity = mcp_server_entity + self.session = None + + async def initialize(self): + """初始化 MCP Server""" + if not self.mcp_server_entity.enable: + return + + # 构建配置字典 + mixed_config = { + 'name': self.mcp_server_entity.name, + 'mode': self.mcp_server_entity.mode, + 'enable': self.mcp_server_entity.enable, + **self.mcp_server_entity.extra_args, + } + + self.session = RuntimeMCPSession( + self.mcp_server_entity.name, + mixed_config, + self.mcp_server_entity.enable, + self.ap + ) + await self.session.initialize() + await self.session.start() + + async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext): + """测试MCP服务器连接""" + try: + task_context.set_current_action(f'Testing connection to {self.mcp_server_entity.name}') + + # 创建临时会话进行测试 + mixed_config = { + 'name': self.mcp_server_entity.name, + 'mode': self.mcp_server_entity.mode, + 'enable': True, # 测试时强制启用 + **self.mcp_server_entity.extra_args, + } + + test_session = RuntimeMCPSession( + self.mcp_server_entity.name, + mixed_config, + enable=True, + ap=self.ap + ) + await test_session.start() + + # 获取工具列表作为测试 + tools_count = len(test_session.functions) + + tool_name_list = [] + for function in test_session.functions: + tool_name_list.append(function.name) + + task_context.set_current_action(f'Successfully connected. Found {tools_count} tools.') + + # 关闭测试会话 + await test_session.shutdown() + + return {'status': 'success', 'tools_count': tools_count, 'tools_names_lists': tool_name_list} + + except Exception as e: + self.ap.logger.error(f'Connection test failed: {str(e)}\n{traceback.format_exc()}') + task_context.set_current_action(f'Connection test failed: {str(e)}') + raise e + + async def test_connection(self) -> str: + """测试 MCP 服务器连接并返回任务 ID""" + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._test_mcp_server_task(task_context=ctx), + kind='mcp-operation', + name=f'mcp-test-{self.mcp_server_entity.name}', + label=f'Testing MCP server {self.mcp_server_entity.name}', + context=ctx, + ) + return wrapper.id + + async def dispose(self): + """清理资源""" + if self.session: + await self.session.shutdown() + class MCPService: @@ -13,6 +111,61 @@ class MCPService: def __init__(self, ap: app.Application) -> None: self.ap = ap + def _convert_server_entity_to_config( + self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] + ) -> dict: + """将数据库实体转换为 loader 需要的配置字典 + + Args: + server_entity: 数据库查询返回的服务器实体或 Row 对象 + + Returns: + 包含服务器配置的字典 + """ + if isinstance(server_entity, sqlalchemy.Row): + server = persistence_mcp.MCPServer(**server_entity._mapping) + else: + server = server_entity + + return { + 'name': server.name, + 'mode': server.mode, + 'enable': server.enable, + 'extra_args': server.extra_args, + } + + async def initialize(self) -> None: + """初始化 MCP Service,从数据库加载所有 MCP 服务器到运行时""" + self.ap.logger.info('Initializing MCP Service and loading servers from database...') + + if not self.ap.tool_mgr or not self.ap.tool_mgr.mcp_tool_loader: + self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization') + return + + try: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) + servers = result.all() + + loaded_count = 0 + failed_count = 0 + + for server in servers: + try: + # 将数据库实体转换为配置字典后传递给 loader + server_config = self._convert_server_entity_to_config(server) + await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) + loaded_count += 1 + self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}') + except Exception as e: + failed_count += 1 + + server_name = getattr(server, 'name', 'unknown') + self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}') + + self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}') + except Exception as e: + self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}') + async def get_mcp_servers(self) -> list[dict]: result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) @@ -22,11 +175,16 @@ class MCPService: async def create_mcp_server(self, server_data: dict) -> str: server_data['uuid'] = str(uuid.uuid4()) await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data)) - server = await self.get_mcp_server(server_data['uuid']) - # TODO: load runtime mcp server session + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid']) + ) + server_entity = result.first() + if server_entity and self.ap.tool_mgr.mcp_tool_loader: + server_config = self._convert_server_entity_to_config(server_entity) + await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) - return server['uuid'] + return server_data['uuid'] async def get_mcp_server(self, server_uuid: str) -> dict | None: result = await self.ap.persistence_mgr.execute_async( @@ -47,17 +205,70 @@ class MCPService: return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None: + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + old_server = result.first() + old_server_name = old_server.name if old_server else None + + await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_mcp.MCPServer) .where(persistence_mcp.MCPServer.uuid == server_uuid) .values(server_data) ) - # TODO: reload runtime mcp server session + if self.ap.tool_mgr.mcp_tool_loader: + + if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: + await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name) + + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + updated_server = result.first() + if updated_server: + # convert entity to config dict + server_config = self._convert_server_entity_to_config(updated_server) + await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config) async def delete_mcp_server(self, server_uuid: str) -> None: + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + server = result.first() + server_name = server.name if server else None + + await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) ) - # TODO: remove runtime mcp server session + + if server_name and self.ap.tool_mgr.mcp_tool_loader: + if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: + await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name) + + async def test_mcp_server(self, server_uuid: str) -> str: + """测试 MCP 服务器连接并返回任务 ID""" + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + server = result.first() + if server is None: + raise ValueError(f'Server not found: {server_uuid}') + + + if isinstance(server, sqlalchemy.Row): + server_entity = persistence_mcp.MCPServer(**server._mapping) + else: + server_entity = server + + runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity) + + + return await runtime_server.test_connection() diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 8df32755..f2991315 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -129,6 +129,7 @@ class BuildAppStage(stage.BootingStage): mcp_service_inst = mcp_service.MCPService(ap) ap.mcp_service = mcp_service_inst + await mcp_service_inst.initialize() ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 721e0782..3f480f0b 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -10,9 +10,7 @@ from mcp.client.sse import sse_client from .. import loader from ....core import app -from ....entity.persistence import mcp as persistence_mcp import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import sqlalchemy class RuntimeMCPSession: @@ -113,8 +111,14 @@ class RuntimeMCPSession: ) async def shutdown(self): - """关闭工具""" - await self.session._exit_stack.aclose() + """关闭会话并清理资源""" + try: + if self.exit_stack: + await self.exit_stack.aclose() + self.functions.clear() + self.session = None + except Exception as e: + self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}') @loader.loader_class('mcp') @@ -134,46 +138,48 @@ class MCPLoader(loader.ToolLoader): self._last_listed_functions = [] async def initialize(self): - await self.load_mcp_servers_from_db() + pass - async def load_mcp_servers_from_db(self): - self.ap.logger.info('Loading MCP servers from db...') - result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) - servers = result.all() - for server in servers: - try: - await self.load_mcp_server(server) - except Exception as e: - self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}') + async def init_runtime_mcp_session(self, server_config: dict): + """从服务器配置创建运行时会话 - async def init_runtime_mcp_session( - self, - server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict, - ): - if isinstance(server_entity, sqlalchemy.Row): - server_entity = persistence_mcp.MCPServer(**server_entity._mapping) - elif isinstance(server_entity, dict): - server_entity = persistence_mcp.MCPServer(**server_entity) + Args: + server_config: 服务器配置字典,必须包含: + - name: 服务器名称 + - mode: 连接模式 (stdio/sse) + - enable: 是否启用 + - extra_args: 额外的配置参数 (可选) + """ + name = server_config['name'] + mode = server_config['mode'] + enable = server_config['enable'] + extra_args = server_config.get('extra_args', {}) mixed_config = { - 'name': server_entity.name, - 'mode': server_entity.mode, - 'enable': server_entity.enable, - **server_entity.extra_args, + 'name': name, + 'mode': mode, + 'enable': enable, + **extra_args, } - session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap) + session = RuntimeMCPSession(name, mixed_config, enable, self.ap) await session.initialize() return session - async def load_mcp_server( - self, - server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict, - ): - session = await self.init_runtime_mcp_session(server_entity) + async def load_mcp_server(self, server_config: dict): + """加载 MCP 服务器到运行时 + + Args: + server_config: 服务器配置字典,必须包含: + - name: 服务器名称 + - mode: 连接模式 (stdio/sse) + - enable: 是否启用 + - extra_args: 额外的配置参数 (可选) + """ + session = await self.init_runtime_mcp_session(server_config) await session.start() - self.sessions[server_entity.name] = session + self.sessions[server_config['name']] = session async def get_tools(self) -> list[resource_tool.LLMTool]: all_functions = [] @@ -186,24 +192,91 @@ class MCPLoader(loader.ToolLoader): return all_functions async def has_tool(self, name: str) -> bool: - return name in [f.name for f in self._last_listed_functions] - - async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: + """检查工具是否存在""" for session in self.sessions.values(): for function in session.functions: if function.name == name: - return await function.func(**parameters) + return True + return False + + async def invoke_tool(self, name: str, parameters: dict) -> typing.Any: + """执行工具调用""" + for session in self.sessions.values(): + for function in session.functions: + if function.name == name: + self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}') + try: + result = await function.func(**parameters) + self.ap.logger.debug(f'MCP tool {name} executed successfully') + return result + except Exception as e: + self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}') + raise raise ValueError(f'Tool not found: {name}') + async def reload_mcp_server(self, server_config: dict): + """重新加载 MCP 服务器(先移除再加载) + + Args: + server_config: 服务器配置字典,必须包含 name 字段 + """ + server_name = server_config['name'] + + if server_name in self.sessions: + await self.remove_mcp_server(server_name) + + # 重新加载 + await self.load_mcp_server(server_config) + async def remove_mcp_server(self, server_name: str): + """移除 MCP 服务器""" if server_name not in self.sessions: - raise ValueError(f'MCP server {server_name} not found') + self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal') + return session = self.sessions.pop(server_name) await session.shutdown() + self.ap.logger.info(f'Removed MCP server: {server_name}') + + def get_session(self, server_name: str) -> RuntimeMCPSession | None: + """获取指定名称的 MCP 会话""" + return self.sessions.get(server_name) + + def has_session(self, server_name: str) -> bool: + """检查是否存在指定名称的 MCP 会话""" + return server_name in self.sessions + + def get_all_server_names(self) -> list[str]: + """获取所有已加载的 MCP 服务器名称""" + return list(self.sessions.keys()) + + def get_server_tool_count(self, server_name: str) -> int: + """获取指定服务器的工具数量""" + session = self.get_session(server_name) + return len(session.functions) if session else 0 + + def get_all_servers_info(self) -> dict[str, dict]: + """获取所有服务器的信息""" + info = {} + for server_name, session in self.sessions.items(): + info[server_name] = { + 'name': server_name, + 'mode': session.server_config.get('mode'), + 'enable': session.enable, + 'tools_count': len(session.functions), + 'tool_names': [f.name for f in session.functions], + } + return info async def shutdown(self): - """关闭工具""" - for session in self.sessions.values(): - await session.shutdown() + """关闭所有工具""" + self.ap.logger.info('Shutting down all MCP sessions...') + for server_name, session in list(self.sessions.items()): + try: + await session.shutdown() + self.ap.logger.debug(f'Shutdown MCP session: {server_name}') + except Exception as e: + self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}') + self.sessions.clear() + self.ap.logger.info('All MCP sessions shutdown complete')