mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 00:06:04 +00:00
fix: try & catch & error
This commit is contained in:
@@ -10,9 +10,7 @@ from mcp.client.sse import sse_client
|
||||
|
||||
from .. import loader
|
||||
from ....core import app
|
||||
from ....entity.persistence import mcp as persistence_mcp
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class RuntimeMCPSession:
|
||||
@@ -113,8 +111,14 @@ class RuntimeMCPSession:
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
await self.session._exit_stack.aclose()
|
||||
"""关闭会话并清理资源"""
|
||||
try:
|
||||
if self.exit_stack:
|
||||
await self.exit_stack.aclose()
|
||||
self.functions.clear()
|
||||
self.session = None
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
|
||||
|
||||
|
||||
@loader.loader_class('mcp')
|
||||
@@ -134,46 +138,48 @@ class MCPLoader(loader.ToolLoader):
|
||||
self._last_listed_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
await self.load_mcp_servers_from_db()
|
||||
pass
|
||||
|
||||
async def load_mcp_servers_from_db(self):
|
||||
self.ap.logger.info('Loading MCP servers from db...')
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
servers = result.all()
|
||||
for server in servers:
|
||||
try:
|
||||
await self.load_mcp_server(server)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
|
||||
async def init_runtime_mcp_session(self, server_config: dict):
|
||||
"""从服务器配置创建运行时会话
|
||||
|
||||
async def init_runtime_mcp_session(
|
||||
self,
|
||||
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
|
||||
):
|
||||
if isinstance(server_entity, sqlalchemy.Row):
|
||||
server_entity = persistence_mcp.MCPServer(**server_entity._mapping)
|
||||
elif isinstance(server_entity, dict):
|
||||
server_entity = persistence_mcp.MCPServer(**server_entity)
|
||||
Args:
|
||||
server_config: 服务器配置字典,必须包含:
|
||||
- name: 服务器名称
|
||||
- mode: 连接模式 (stdio/sse)
|
||||
- enable: 是否启用
|
||||
- extra_args: 额外的配置参数 (可选)
|
||||
"""
|
||||
name = server_config['name']
|
||||
mode = server_config['mode']
|
||||
enable = server_config['enable']
|
||||
extra_args = server_config.get('extra_args', {})
|
||||
|
||||
mixed_config = {
|
||||
'name': server_entity.name,
|
||||
'mode': server_entity.mode,
|
||||
'enable': server_entity.enable,
|
||||
**server_entity.extra_args,
|
||||
'name': name,
|
||||
'mode': mode,
|
||||
'enable': enable,
|
||||
**extra_args,
|
||||
}
|
||||
|
||||
session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap)
|
||||
session = RuntimeMCPSession(name, mixed_config, enable, self.ap)
|
||||
await session.initialize()
|
||||
|
||||
return session
|
||||
|
||||
async def load_mcp_server(
|
||||
self,
|
||||
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
|
||||
):
|
||||
session = await self.init_runtime_mcp_session(server_entity)
|
||||
async def load_mcp_server(self, server_config: dict):
|
||||
"""加载 MCP 服务器到运行时
|
||||
|
||||
Args:
|
||||
server_config: 服务器配置字典,必须包含:
|
||||
- name: 服务器名称
|
||||
- mode: 连接模式 (stdio/sse)
|
||||
- enable: 是否启用
|
||||
- extra_args: 额外的配置参数 (可选)
|
||||
"""
|
||||
session = await self.init_runtime_mcp_session(server_config)
|
||||
await session.start()
|
||||
self.sessions[server_entity.name] = session
|
||||
self.sessions[server_config['name']] = session
|
||||
|
||||
async def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
all_functions = []
|
||||
@@ -186,24 +192,91 @@ class MCPLoader(loader.ToolLoader):
|
||||
return all_functions
|
||||
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
return name in [f.name for f in self._last_listed_functions]
|
||||
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""检查工具是否存在"""
|
||||
for session in self.sessions.values():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
return await function.func(**parameters)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
for session in self.sessions.values():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
|
||||
try:
|
||||
result = await function.func(**parameters)
|
||||
self.ap.logger.debug(f'MCP tool {name} executed successfully')
|
||||
return result
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}')
|
||||
raise
|
||||
|
||||
raise ValueError(f'Tool not found: {name}')
|
||||
|
||||
async def reload_mcp_server(self, server_config: dict):
|
||||
"""重新加载 MCP 服务器(先移除再加载)
|
||||
|
||||
Args:
|
||||
server_config: 服务器配置字典,必须包含 name 字段
|
||||
"""
|
||||
server_name = server_config['name']
|
||||
|
||||
if server_name in self.sessions:
|
||||
await self.remove_mcp_server(server_name)
|
||||
|
||||
# 重新加载
|
||||
await self.load_mcp_server(server_config)
|
||||
|
||||
async def remove_mcp_server(self, server_name: str):
|
||||
"""移除 MCP 服务器"""
|
||||
if server_name not in self.sessions:
|
||||
raise ValueError(f'MCP server {server_name} not found')
|
||||
self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal')
|
||||
return
|
||||
|
||||
session = self.sessions.pop(server_name)
|
||||
await session.shutdown()
|
||||
self.ap.logger.info(f'Removed MCP server: {server_name}')
|
||||
|
||||
def get_session(self, server_name: str) -> RuntimeMCPSession | None:
|
||||
"""获取指定名称的 MCP 会话"""
|
||||
return self.sessions.get(server_name)
|
||||
|
||||
def has_session(self, server_name: str) -> bool:
|
||||
"""检查是否存在指定名称的 MCP 会话"""
|
||||
return server_name in self.sessions
|
||||
|
||||
def get_all_server_names(self) -> list[str]:
|
||||
"""获取所有已加载的 MCP 服务器名称"""
|
||||
return list(self.sessions.keys())
|
||||
|
||||
def get_server_tool_count(self, server_name: str) -> int:
|
||||
"""获取指定服务器的工具数量"""
|
||||
session = self.get_session(server_name)
|
||||
return len(session.functions) if session else 0
|
||||
|
||||
def get_all_servers_info(self) -> dict[str, dict]:
|
||||
"""获取所有服务器的信息"""
|
||||
info = {}
|
||||
for server_name, session in self.sessions.items():
|
||||
info[server_name] = {
|
||||
'name': server_name,
|
||||
'mode': session.server_config.get('mode'),
|
||||
'enable': session.enable,
|
||||
'tools_count': len(session.functions),
|
||||
'tool_names': [f.name for f in session.functions],
|
||||
}
|
||||
return info
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
for session in self.sessions.values():
|
||||
await session.shutdown()
|
||||
"""关闭所有工具"""
|
||||
self.ap.logger.info('Shutting down all MCP sessions...')
|
||||
for server_name, session in list(self.sessions.items()):
|
||||
try:
|
||||
await session.shutdown()
|
||||
self.ap.logger.debug(f'Shutdown MCP session: {server_name}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}')
|
||||
self.sessions.clear()
|
||||
self.ap.logger.info('All MCP sessions shutdown complete')
|
||||
|
||||
Reference in New Issue
Block a user