fix: try & catch & error

This commit is contained in:
wangcham
2025-11-02 12:37:00 +00:00
parent 4c0917556f
commit c2d752f9e9
4 changed files with 349 additions and 99 deletions

View File

@@ -10,9 +10,7 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import sqlalchemy
class RuntimeMCPSession:
@@ -113,8 +111,14 @@ class RuntimeMCPSession:
)
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
"""关闭会话并清理资源"""
try:
if self.exit_stack:
await self.exit_stack.aclose()
self.functions.clear()
self.session = None
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
@loader.loader_class('mcp')
@@ -134,46 +138,48 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
await self.load_mcp_servers_from_db()
pass
async def load_mcp_servers_from_db(self):
self.ap.logger.info('Loading MCP servers from db...')
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
for server in servers:
try:
await self.load_mcp_server(server)
except Exception as e:
self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
async def init_runtime_mcp_session(self, server_config: dict):
"""从服务器配置创建运行时会话
async def init_runtime_mcp_session(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
if isinstance(server_entity, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server_entity._mapping)
elif isinstance(server_entity, dict):
server_entity = persistence_mcp.MCPServer(**server_entity)
Args:
server_config: 服务器配置字典,必须包含:
- name: 服务器名称
- mode: 连接模式 (stdio/sse)
- enable: 是否启用
- extra_args: 额外的配置参数 (可选)
"""
name = server_config['name']
mode = server_config['mode']
enable = server_config['enable']
extra_args = server_config.get('extra_args', {})
mixed_config = {
'name': server_entity.name,
'mode': server_entity.mode,
'enable': server_entity.enable,
**server_entity.extra_args,
'name': name,
'mode': mode,
'enable': enable,
**extra_args,
}
session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap)
session = RuntimeMCPSession(name, mixed_config, enable, self.ap)
await session.initialize()
return session
async def load_mcp_server(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
session = await self.init_runtime_mcp_session(server_entity)
async def load_mcp_server(self, server_config: dict):
"""加载 MCP 服务器到运行时
Args:
server_config: 服务器配置字典,必须包含:
- name: 服务器名称
- mode: 连接模式 (stdio/sse)
- enable: 是否启用
- extra_args: 额外的配置参数 (可选)
"""
session = await self.init_runtime_mcp_session(server_config)
await session.start()
self.sessions[server_entity.name] = session
self.sessions[server_config['name']] = session
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
@@ -186,24 +192,91 @@ class MCPLoader(loader.ToolLoader):
return all_functions
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
"""检查工具是否存在"""
for session in self.sessions.values():
for function in session.functions:
if function.name == name:
return await function.func(**parameters)
return True
return False
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
for session in self.sessions.values():
for function in session.functions:
if function.name == name:
self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
try:
result = await function.func(**parameters)
self.ap.logger.debug(f'MCP tool {name} executed successfully')
return result
except Exception as e:
self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}')
raise
raise ValueError(f'Tool not found: {name}')
async def reload_mcp_server(self, server_config: dict):
"""重新加载 MCP 服务器(先移除再加载)
Args:
server_config: 服务器配置字典,必须包含 name 字段
"""
server_name = server_config['name']
if server_name in self.sessions:
await self.remove_mcp_server(server_name)
# 重新加载
await self.load_mcp_server(server_config)
async def remove_mcp_server(self, server_name: str):
"""移除 MCP 服务器"""
if server_name not in self.sessions:
raise ValueError(f'MCP server {server_name} not found')
self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal')
return
session = self.sessions.pop(server_name)
await session.shutdown()
self.ap.logger.info(f'Removed MCP server: {server_name}')
def get_session(self, server_name: str) -> RuntimeMCPSession | None:
"""获取指定名称的 MCP 会话"""
return self.sessions.get(server_name)
def has_session(self, server_name: str) -> bool:
"""检查是否存在指定名称的 MCP 会话"""
return server_name in self.sessions
def get_all_server_names(self) -> list[str]:
"""获取所有已加载的 MCP 服务器名称"""
return list(self.sessions.keys())
def get_server_tool_count(self, server_name: str) -> int:
"""获取指定服务器的工具数量"""
session = self.get_session(server_name)
return len(session.functions) if session else 0
def get_all_servers_info(self) -> dict[str, dict]:
"""获取所有服务器的信息"""
info = {}
for server_name, session in self.sessions.items():
info[server_name] = {
'name': server_name,
'mode': session.server_config.get('mode'),
'enable': session.enable,
'tools_count': len(session.functions),
'tool_names': [f.name for f in session.functions],
}
return info
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()
"""关闭所有工具"""
self.ap.logger.info('Shutting down all MCP sessions...')
for server_name, session in list(self.sessions.items()):
try:
await session.shutdown()
self.ap.logger.debug(f'Shutdown MCP session: {server_name}')
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}')
self.sessions.clear()
self.ap.logger.info('All MCP sessions shutdown complete')