feat(mcp): available for provider reloading

This commit is contained in:
Junyan Qin
2025-03-19 12:41:04 +08:00
parent 40275c3ef1
commit 5640dc332d
5 changed files with 37 additions and 4 deletions

View File

@@ -204,6 +204,8 @@ class Application:
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info("执行热重载 scope="+scope)
await self.tool_mgr.shutdown()
latest_llm_model_config = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
self.llm_models_meta = latest_llm_model_config
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)

View File

@@ -46,4 +46,9 @@ class ToolLoader(abc.ABC):
@abc.abstractmethod
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
pass
@abc.abstractmethod
async def shutdown(self):
"""关闭工具"""
pass

View File

@@ -30,8 +30,11 @@ class RuntimeMCPSession:
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.session = None
self.exit_stack = AsyncExitStack()
self.functions = []
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
@@ -101,6 +104,10 @@ class RuntimeMCPSession:
func=func,
))
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
@loader.loader_class("mcp")
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
@@ -112,6 +119,11 @@ class MCPLoader(loader.ToolLoader):
_last_listed_functions: list[tools_entities.LLMFunction] = []
def __init__(self, ap: app.Application):
super().__init__(ap)
self.sessions = {}
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []):
@@ -125,7 +137,7 @@ class MCPLoader(loader.ToolLoader):
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
all_functions = []
for server_name, session in self.sessions.items():
for session in self.sessions.values():
all_functions.extend(session.functions)
self._last_listed_functions = all_functions
@@ -141,4 +153,9 @@ class MCPLoader(loader.ToolLoader):
if function.name == name:
return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}")
raise ValueError(f"未找到工具: {name}")
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()

View File

@@ -85,4 +85,8 @@ class PluginToolLoader(loader.ToolLoader):
},
function_name=function.name,
function_description=function.description,
)
)
async def shutdown(self):
"""关闭工具"""
pass

View File

@@ -100,4 +100,9 @@ class ToolManager:
if await loader.has_tool(name):
return await loader.invoke_tool(query, name, parameters)
else:
raise ValueError(f"未找到工具: {name}")
raise ValueError(f"未找到工具: {name}")
async def shutdown(self):
"""关闭所有工具"""
for loader in self.loaders:
await loader.shutdown()