from __future__ import annotations import typing from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.sse import sse_client from .. import loader, entities as tools_entities from ....core import app, entities as core_entities class RuntimeMCPSession: """运行时 MCP 会话""" ap: app.Application server_name: str server_config: dict session: ClientSession exit_stack: AsyncExitStack functions: list[tools_entities.LLMFunction] = [] def __init__(self, server_name: str, server_config: dict, ap: app.Application): self.server_name = server_name self.server_config = server_config self.ap = ap self.session = None self.exit_stack = AsyncExitStack() self.functions = [] async def _init_stdio_python_server(self): server_params = StdioServerParameters( command=self.server_config['command'], args=self.server_config['args'], env=self.server_config['env'], ) stdio_transport = await self.exit_stack.enter_async_context( stdio_client(server_params) ) stdio, write = stdio_transport self.session = await self.exit_stack.enter_async_context( ClientSession(stdio, write) ) await self.session.initialize() async def _init_sse_server(self): sse_transport = await self.exit_stack.enter_async_context( sse_client( self.server_config['url'], headers=self.server_config.get('headers', {}), timeout=self.server_config.get('timeout', 10), ) ) sseio, write = sse_transport self.session = await self.exit_stack.enter_async_context( ClientSession(sseio, write) ) await self.session.initialize() async def initialize(self): self.ap.logger.debug( f'初始化 MCP 会话: {self.server_name} {self.server_config}' ) if self.server_config['mode'] == 'stdio': await self._init_stdio_python_server() elif self.server_config['mode'] == 'sse': await self._init_sse_server() else: raise ValueError( f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}' ) tools = await self.session.list_tools() self.ap.logger.debug(f'获取 MCP 工具: {tools}') for tool in tools.tools: async def func(query: core_entities.Query, **kwargs): result = await self.session.call_tool(tool.name, kwargs) if result.isError: raise Exception(result.content[0].text) return result.content[0].text func.__name__ = tool.name self.functions.append( tools_entities.LLMFunction( name=tool.name, human_desc=tool.description, description=tool.description, parameters=tool.inputSchema, func=func, ) ) async def shutdown(self): """关闭工具""" await self.session._exit_stack.aclose() @loader.loader_class('mcp') class MCPLoader(loader.ToolLoader): """MCP 工具加载器。 在此加载器中管理所有与 MCP Server 的连接。 """ sessions: dict[str, RuntimeMCPSession] = {} _last_listed_functions: list[tools_entities.LLMFunction] = [] def __init__(self, ap: app.Application): super().__init__(ap) self.sessions = {} self._last_listed_functions = [] async def initialize(self): for server_config in self.ap.instance_config.data.get('mcp', {}).get( 'servers', [] ): if not server_config['enable']: continue session = RuntimeMCPSession(server_config['name'], server_config, self.ap) await session.initialize() # self.ap.event_loop.create_task(session.initialize()) self.sessions[server_config['name']] = session async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]: all_functions = [] for session in self.sessions.values(): all_functions.extend(session.functions) self._last_listed_functions = all_functions return all_functions async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] async def invoke_tool( self, query: core_entities.Query, name: str, parameters: dict ) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: return await function.func(query, **parameters) raise ValueError(f'未找到工具: {name}') async def shutdown(self): """关闭工具""" for session in self.sessions.values(): await session.shutdown()