diff --git a/pkg/api/http/service/mcp.py b/pkg/api/http/service/mcp.py index 3766e7d6..328b9c20 100644 --- a/pkg/api/http/service/mcp.py +++ b/pkg/api/http/service/mcp.py @@ -72,6 +72,7 @@ class MCPService: ) old_server = result.first() old_server_name = old_server.name if old_server else None + old_enable = old_server.enable if old_server else False await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_mcp.MCPServer) @@ -80,18 +81,38 @@ class MCPService: ) if self.ap.tool_mgr.mcp_tool_loader: - if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: - await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name) + new_enable = server_data.get('enable', False) - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) - ) - updated_server = result.first() - if updated_server: - # convert entity to config dict - server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server) - task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config)) - self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task) + need_remove = old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions + need_start = new_enable + + + if old_enable and not new_enable: + if need_remove: + await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name) + + elif not old_enable and new_enable: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + updated_server = result.first() + if updated_server: + server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server) + task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config)) + self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task) + + elif old_enable and new_enable: + if need_remove: + await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name) + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) + ) + updated_server = result.first() + if updated_server: + server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server) + task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config)) + self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task) + async def delete_mcp_server(self, server_uuid: str) -> None: result = await self.ap.persistence_mgr.execute_async( diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index edff9e01..99e3021d 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -43,7 +43,12 @@ class RuntimeMCPSession: # connected: bool status: MCPSessionStatus - last_test_error_message: str + + _lifecycle_task: asyncio.Task | None + + _shutdown_event: asyncio.Event + + _ready_event: asyncio.Event def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application): self.server_name = server_name @@ -56,7 +61,10 @@ class RuntimeMCPSession: self.functions = [] self.status = MCPSessionStatus.CONNECTING - self.last_test_error_message = '' + + self._lifecycle_task = None + self._shutdown_event = asyncio.Event() + self._ready_event = asyncio.Event() async def _init_stdio_python_server(self): server_params = StdioServerParameters( @@ -89,10 +97,8 @@ class RuntimeMCPSession: await self.session.initialize() - async def start(self): - if not self.enable: - return - + async def _lifecycle_loop(self): + """在后台任务中管理整个MCP会话的生命周期""" try: if self.server_config['mode'] == 'stdio': await self._init_stdio_python_server() @@ -104,11 +110,45 @@ class RuntimeMCPSession: await self.refresh() self.status = MCPSessionStatus.CONNECTED - self.last_test_error_message = '' + + # 通知start()方法连接已建立 + self._ready_event.set() + + # 等待shutdown信号 + await self._shutdown_event.wait() + except Exception as e: self.status = MCPSessionStatus.ERROR - self.last_test_error_message = str(e) - raise e + self.ap.logger.error(f'Error in MCP session lifecycle {self.server_name}: {e}\n{traceback.format_exc()}') + # 即使出错也要设置ready事件,让start()方法知道初始化已完成 + self._ready_event.set() + finally: + # 在同一个任务中清理所有资源 + try: + if self.exit_stack: + await self.exit_stack.aclose() + self.functions.clear() + self.session = None + except Exception as e: + self.ap.logger.error(f'Error cleaning up MCP session {self.server_name}: {e}\n{traceback.format_exc()}') + + async def start(self): + if not self.enable: + return + + # 创建后台任务来管理生命周期 + self._lifecycle_task = asyncio.create_task(self._lifecycle_loop()) + + # 等待连接建立或失败(带超时) + try: + await asyncio.wait_for(self._ready_event.wait(), timeout=30.0) + except asyncio.TimeoutError: + self.status = MCPSessionStatus.ERROR + raise Exception('Connection timeout after 30 seconds') + + # 检查是否有错误 + if self.status == MCPSessionStatus.ERROR: + raise Exception('Connection failed, please check URL') async def refresh(self): self.functions.clear() @@ -143,7 +183,6 @@ class RuntimeMCPSession: def get_runtime_info_dict(self) -> dict: return { 'status': self.status.value, - 'error_message': self.last_test_error_message, 'tool_count': len(self.get_tools()), 'tools': [ { @@ -157,10 +196,22 @@ class RuntimeMCPSession: async def shutdown(self): """关闭会话并清理资源""" try: - if self.exit_stack: - await self.exit_stack.aclose() - self.functions.clear() - self.session = None + # 设置shutdown事件,通知lifecycle任务退出 + self._shutdown_event.set() + + # 等待lifecycle任务完成(带超时) + if self._lifecycle_task and not self._lifecycle_task.done(): + try: + await asyncio.wait_for(self._lifecycle_task, timeout=5.0) + except asyncio.TimeoutError: + self.ap.logger.warning(f'MCP session {self.server_name} shutdown timeout, cancelling task') + self._lifecycle_task.cancel() + try: + await self._lifecycle_task + except asyncio.CancelledError: + pass + + self.ap.logger.info(f'MCP session {self.server_name} shutdown complete') except Exception as e: self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}') diff --git a/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx b/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx index fd19cd4b..937da834 100644 --- a/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx +++ b/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx @@ -134,7 +134,7 @@ export default function MCPCardComponent({