"""MCP server management: persistence CRUD, runtime loading and connection tests."""

from __future__ import annotations

import traceback
import uuid

import sqlalchemy

from ....core import app
from ....entity.persistence import mcp as persistence_mcp
from ....core import taskmgr
from ....provider.tools.loaders.mcp import RuntimeMCPSession


class RuntimeMCPServer:
    """Runtime representation of a persisted MCP server.

    Wraps an ``MCPServer`` entity and, once ``initialize`` has run for an
    enabled server, the ``RuntimeMCPSession`` created from it.
    """

    ap: app.Application

    mcp_server_entity: persistence_mcp.MCPServer

    session: RuntimeMCPSession | None = None

    def __init__(self, ap: app.Application, mcp_server_entity: persistence_mcp.MCPServer):
        self.ap = ap
        self.mcp_server_entity = mcp_server_entity
        self.session = None

    def _build_session_config(self, enable: bool) -> dict:
        """Build the config dict handed to ``RuntimeMCPSession``.

        Args:
            enable: Value stored under the ``'enable'`` key.

        Returns:
            A dict with ``name``, ``mode`` and ``enable`` plus every key of the
            entity's ``extra_args``.
        """
        entity = self.mcp_server_entity
        return {
            'name': entity.name,
            'mode': entity.mode,
            'enable': enable,
            # Spread last so extra_args keys win on collision, as they always
            # did; `or {}` keeps a missing (None) extra_args from raising.
            **(entity.extra_args or {}),
        }

    async def _shutdown_quietly(self, session: RuntimeMCPSession) -> None:
        """Shut ``session`` down, logging instead of raising on failure.

        Used on error paths so a cleanup failure cannot replace the error
        that is already propagating.
        """
        try:
            await session.shutdown()
        except Exception as cleanup_error:
            self.ap.logger.warning(f'Failed to shut down MCP test session: {cleanup_error}')

    async def initialize(self):
        """Create and start the runtime session; a disabled server is skipped."""
        entity = self.mcp_server_entity
        if not entity.enable:
            return

        self.session = RuntimeMCPSession(
            entity.name,
            self._build_session_config(entity.enable),
            entity.enable,
            self.ap,
        )
        await self.session.initialize()
        await self.session.start()

    async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext):
        """Connect to the server with a throwaway session and list its tools.

        Args:
            task_context: Task context whose current action is updated with
                progress and with the failure message, if any.

        Returns:
            A dict with ``status``, ``tools_count`` and ``tools_names_lists``.

        Raises:
            Exception: Whatever the connection attempt raised, re-raised
                after it has been logged and recorded on ``task_context``.
        """
        server_name = self.mcp_server_entity.name
        try:
            task_context.set_current_action(f'Testing connection to {server_name}')

            # The test always runs with the session enabled, regardless of
            # the entity's own `enable` flag.
            test_session = RuntimeMCPSession(
                server_name,
                self._build_session_config(True),
                enable=True,
                ap=self.ap,
            )

            try:
                await test_session.start()

                # Listing the exposed tools doubles as the connectivity check.
                tools_count = len(test_session.functions)
                tool_name_list = [function.name for function in test_session.functions]

                task_context.set_current_action(f'Successfully connected. Found {tools_count} tools.')
            except Exception:
                # Do not leave the temporary session open when the test fails.
                await self._shutdown_quietly(test_session)
                raise

            # Success path: a shutdown error here still propagates, as before.
            await test_session.shutdown()

            return {'status': 'success', 'tools_count': tools_count, 'tools_names_lists': tool_name_list}
        except Exception as e:
            self.ap.logger.error(f'Connection test failed: {str(e)}\n{traceback.format_exc()}')
            task_context.set_current_action(f'Connection test failed: {str(e)}')
            raise

    async def test_connection(self) -> str:
        """Schedule the connection test as a user task and return the task's id."""
        ctx = taskmgr.TaskContext.new()
        server_name = self.mcp_server_entity.name
        wrapper = self.ap.task_mgr.create_user_task(
            self._test_mcp_server_task(task_context=ctx),
            kind='mcp-operation',
            name=f'mcp-test-{server_name}',
            label=f'Testing MCP server {server_name}',
            context=ctx,
        )
        return wrapper.id

    async def dispose(self):
        """Shut down the runtime session, if one was created."""
        if self.session:
            await self.session.shutdown()


class MCPService:
    """CRUD over persisted MCP servers, kept in sync with the runtime MCP tool loader."""

    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    def _get_mcp_tool_loader(self):
        """Return the MCP tool loader, or ``None`` when it is unavailable.

        Covers both a missing tool manager and a missing loader so every
        caller applies the same availability check.
        """
        tool_mgr = self.ap.tool_mgr
        if not tool_mgr:
            return None
        loader = tool_mgr.mcp_tool_loader
        return loader if loader else None

    @staticmethod
    def _to_server_entity(
        server: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer],
    ) -> persistence_mcp.MCPServer:
        """Return ``server`` as an ``MCPServer``, rebuilding it from a ``Row`` if needed."""
        if isinstance(server, sqlalchemy.Row):
            return persistence_mcp.MCPServer(**server._mapping)
        return server

    async def _fetch_server_by_uuid(self, server_uuid: str):
        """Return the first result row for ``server_uuid``, or ``None`` if there is none."""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
        )
        return result.first()

    def _convert_server_entity_to_config(
        self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer]
    ) -> dict:
        """Convert a database entity into the config dict the loader expects.

        Args:
            server_entity: Server entity or ``Row`` returned by a query.

        Returns:
            A dict with the keys ``name``, ``mode``, ``enable`` and
            ``extra_args``.
        """
        server = self._to_server_entity(server_entity)
        return {
            'name': server.name,
            'mode': server.mode,
            'enable': server.enable,
            'extra_args': server.extra_args,
        }

    async def initialize(self) -> None:
        """Load every persisted MCP server into the runtime tool loader.

        A failure for one server is logged and counted without stopping the
        others; a failure of the query itself is logged and swallowed.
        """
        self.ap.logger.info('Initializing MCP Service and loading servers from database...')

        loader = self._get_mcp_tool_loader()
        if loader is None:
            self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization')
            return

        try:
            result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
            servers = result.all()

            loaded_count = 0
            failed_count = 0

            for server in servers:
                try:
                    # The loader takes a plain config dict, not the entity.
                    server_config = self._convert_server_entity_to_config(server)
                    await loader.load_mcp_server(server_config)
                    loaded_count += 1
                    self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}')
                except Exception as e:
                    failed_count += 1
                    server_name = getattr(server, 'name', 'unknown')
                    self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}')

            self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}')
        except Exception as e:
            self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}')

    async def get_mcp_servers(self) -> list[dict]:
        """Return every persisted MCP server, serialized."""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
        serialize = self.ap.persistence_mgr.serialize_model
        return [serialize(persistence_mcp.MCPServer, server) for server in result.all()]

    async def create_mcp_server(self, server_data: dict) -> str:
        """Insert a new MCP server and load it into the runtime.

        Args:
            server_data: Column values for the new row. The generated uuid is
                written into this dict under ``'uuid'``.

        Returns:
            The uuid generated for the new server.
        """
        server_data['uuid'] = str(uuid.uuid4())
        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))

        # Re-read the row so the loader receives exactly what was stored.
        server_entity = await self._fetch_server_by_uuid(server_data['uuid'])
        loader = self._get_mcp_tool_loader()
        if server_entity and loader is not None:
            server_config = self._convert_server_entity_to_config(server_entity)
            await loader.load_mcp_server(server_config)

        return server_data['uuid']

    async def get_mcp_server(self, server_uuid: str) -> dict | None:
        """Return the serialized server with ``server_uuid``, or ``None`` if absent."""
        server = await self._fetch_server_by_uuid(server_uuid)
        if server is None:
            return None
        return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)

    async def get_mcp_server_by_name(self, server_name: str) -> dict | None:
        """Return the serialized server named ``server_name``, or ``None`` if absent."""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.name == server_name)
        )
        server = result.first()
        if server is None:
            return None
        return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)

    async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
        """Update a persisted server and reload it in the runtime.

        Args:
            server_uuid: uuid of the server to update.
            server_data: Column values to write.
        """
        # Capture the name before the update: the runtime session is keyed
        # by the old name, which the update may change.
        old_server = await self._fetch_server_by_uuid(server_uuid)
        old_server_name = old_server.name if old_server else None

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_mcp.MCPServer)
            .where(persistence_mcp.MCPServer.uuid == server_uuid)
            .values(server_data)
        )

        loader = self._get_mcp_tool_loader()
        if loader is None:
            return

        if old_server_name and old_server_name in loader.sessions:
            await loader.remove_mcp_server(old_server_name)

        updated_server = await self._fetch_server_by_uuid(server_uuid)
        if updated_server:
            server_config = self._convert_server_entity_to_config(updated_server)
            await loader.load_mcp_server(server_config)

    async def delete_mcp_server(self, server_uuid: str) -> None:
        """Delete a persisted server and remove its runtime session, if any."""
        server = await self._fetch_server_by_uuid(server_uuid)
        server_name = server.name if server else None

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
        )

        loader = self._get_mcp_tool_loader()
        if server_name and loader is not None and server_name in loader.sessions:
            await loader.remove_mcp_server(server_name)

    async def test_mcp_server(self, server_uuid: str) -> str:
        """Start a connection test for a persisted server and return the task's id.

        Raises:
            ValueError: If no server with ``server_uuid`` exists.
        """
        server = await self._fetch_server_by_uuid(server_uuid)
        if server is None:
            raise ValueError(f'Server not found: {server_uuid}')

        runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=self._to_server_entity(server))
        return await runtime_server.test_connection()