From 5640dc332d01cb514b3966876d86f05d796ee5aa Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 19 Mar 2025 12:41:04 +0800 Subject: [PATCH] feat(mcp): available for provider reloading --- pkg/core/app.py | 2 ++ pkg/provider/tools/loader.py | 5 +++++ pkg/provider/tools/loaders/mcp.py | 21 +++++++++++++++++++-- pkg/provider/tools/loaders/plugin.py | 6 +++++- pkg/provider/tools/toolmgr.py | 7 ++++++- 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/core/app.py b/pkg/core/app.py index fd0c59a3..8fd36d63 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -204,6 +204,8 @@ class Application: case core_entities.LifecycleControlScope.PROVIDER.value: self.logger.info("执行热重载 scope="+scope) + await self.tool_mgr.shutdown() + latest_llm_model_config = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json") self.llm_models_meta = latest_llm_model_config llm_model_mgr_inst = llm_model_mgr.ModelManager(self) diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index 82e6440d..cae4a63f 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -46,4 +46,9 @@ class ToolLoader(abc.ABC): @abc.abstractmethod async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: """执行工具调用""" + pass + + @abc.abstractmethod + async def shutdown(self): + """关闭工具""" pass \ No newline at end of file diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 4d15bf60..a475f9b7 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -30,8 +30,11 @@ class RuntimeMCPSession: self.server_name = server_name self.server_config = server_config self.ap = ap + + self.session = None self.exit_stack = AsyncExitStack() + self.functions = [] async def _init_stdio_python_server(self): server_params = StdioServerParameters( @@ -101,6 +104,10 @@ class RuntimeMCPSession: func=func, )) + async def shutdown(self): + """关闭工具""" + await self.session._exit_stack.aclose() + @loader.loader_class("mcp") class MCPLoader(loader.ToolLoader): """MCP 工具加载器。 @@ -112,6 +119,11 @@ class MCPLoader(loader.ToolLoader): _last_listed_functions: list[tools_entities.LLMFunction] = [] + def __init__(self, ap: app.Application): + super().__init__(ap) + self.sessions = {} + self._last_listed_functions = [] + async def initialize(self): for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []): @@ -125,7 +137,7 @@ class MCPLoader(loader.ToolLoader): async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: all_functions = [] - for server_name, session in self.sessions.items(): + for session in self.sessions.values(): all_functions.extend(session.functions) self._last_listed_functions = all_functions @@ -141,4 +153,9 @@ class MCPLoader(loader.ToolLoader): if function.name == name: return await function.func(query, **parameters) - raise ValueError(f"未找到工具: {name}") \ No newline at end of file + raise ValueError(f"未找到工具: {name}") + + async def shutdown(self): + """关闭工具""" + for session in self.sessions.values(): + await session.shutdown() diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index da0bc555..08211334 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -85,4 +85,8 @@ class PluginToolLoader(loader.ToolLoader): }, function_name=function.name, function_description=function.description, - ) \ No newline at end of file + ) + + async def shutdown(self): + """关闭工具""" + pass diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 1688937d..64befd8c 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -100,4 +100,9 @@ class ToolManager: if await loader.has_tool(name): return await loader.invoke_tool(query, name, parameters) else: - raise ValueError(f"未找到工具: {name}") \ No newline at end of file + raise ValueError(f"未找到工具: {name}") + + async def shutdown(self): + """关闭所有工具""" + for loader in self.loaders: + await loader.shutdown()