feat(mcp): available for provider reloading

This commit is contained in:
Junyan Qin
2025-03-19 12:41:04 +08:00
parent 40275c3ef1
commit 5640dc332d
5 changed files with 37 additions and 4 deletions

View File

@@ -30,8 +30,11 @@ class RuntimeMCPSession:
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.session = None
self.exit_stack = AsyncExitStack()
self.functions = []
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
@@ -101,6 +104,10 @@ class RuntimeMCPSession:
func=func,
))
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
@loader.loader_class("mcp")
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
@@ -112,6 +119,11 @@ class MCPLoader(loader.ToolLoader):
_last_listed_functions: list[tools_entities.LLMFunction] = []
def __init__(self, ap: app.Application):
super().__init__(ap)
self.sessions = {}
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []):
@@ -125,7 +137,7 @@ class MCPLoader(loader.ToolLoader):
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
all_functions = []
for server_name, session in self.sessions.items():
for session in self.sessions.values():
all_functions.extend(session.functions)
self._last_listed_functions = all_functions
@@ -141,4 +153,9 @@ class MCPLoader(loader.ToolLoader):
if function.name == name:
return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}")
raise ValueError(f"未找到工具: {name}")
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()

View File

@@ -85,4 +85,8 @@ class PluginToolLoader(loader.ToolLoader):
},
function_name=function.name,
function_description=function.description,
)
)
async def shutdown(self):
"""关闭工具"""
pass