feat(mcp): available for provider reloading

This commit is contained in:
Junyan Qin
2025-03-19 12:41:04 +08:00
parent 40275c3ef1
commit 5640dc332d
5 changed files with 37 additions and 4 deletions

View File

@@ -204,6 +204,8 @@ class Application:
case core_entities.LifecycleControlScope.PROVIDER.value: case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info("执行热重载 scope="+scope) self.logger.info("执行热重载 scope="+scope)
await self.tool_mgr.shutdown()
latest_llm_model_config = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json") latest_llm_model_config = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
self.llm_models_meta = latest_llm_model_config self.llm_models_meta = latest_llm_model_config
llm_model_mgr_inst = llm_model_mgr.ModelManager(self) llm_model_mgr_inst = llm_model_mgr.ModelManager(self)

View File

@@ -47,3 +47,8 @@ class ToolLoader(abc.ABC):
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行工具调用""" """执行工具调用"""
pass pass
@abc.abstractmethod
async def shutdown(self):
"""关闭工具"""
pass

View File

@@ -31,7 +31,10 @@ class RuntimeMCPSession:
self.server_config = server_config self.server_config = server_config
self.ap = ap self.ap = ap
self.session = None
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()
self.functions = []
async def _init_stdio_python_server(self): async def _init_stdio_python_server(self):
server_params = StdioServerParameters( server_params = StdioServerParameters(
@@ -101,6 +104,10 @@ class RuntimeMCPSession:
func=func, func=func,
)) ))
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
@loader.loader_class("mcp") @loader.loader_class("mcp")
class MCPLoader(loader.ToolLoader): class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。 """MCP 工具加载器。
@@ -112,6 +119,11 @@ class MCPLoader(loader.ToolLoader):
_last_listed_functions: list[tools_entities.LLMFunction] = [] _last_listed_functions: list[tools_entities.LLMFunction] = []
def __init__(self, ap: app.Application):
super().__init__(ap)
self.sessions = {}
self._last_listed_functions = []
async def initialize(self): async def initialize(self):
for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []): for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []):
@@ -125,7 +137,7 @@ class MCPLoader(loader.ToolLoader):
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
all_functions = [] all_functions = []
for server_name, session in self.sessions.items(): for session in self.sessions.values():
all_functions.extend(session.functions) all_functions.extend(session.functions)
self._last_listed_functions = all_functions self._last_listed_functions = all_functions
@@ -142,3 +154,8 @@ class MCPLoader(loader.ToolLoader):
return await function.func(query, **parameters) return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}") raise ValueError(f"未找到工具: {name}")
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()

View File

@@ -86,3 +86,7 @@ class PluginToolLoader(loader.ToolLoader):
function_name=function.name, function_name=function.name,
function_description=function.description, function_description=function.description,
) )
async def shutdown(self):
"""关闭工具"""
pass

View File

@@ -101,3 +101,8 @@ class ToolManager:
return await loader.invoke_tool(query, name, parameters) return await loader.invoke_tool(query, name, parameters)
else: else:
raise ValueError(f"未找到工具: {name}") raise ValueError(f"未找到工具: {name}")
async def shutdown(self):
"""关闭所有工具"""
for loader in self.loaders:
await loader.shutdown()