fix: try & catch & error

This commit is contained in:
wangcham
2025-11-02 12:37:00 +00:00
parent 4c0917556f
commit c2d752f9e9
4 changed files with 349 additions and 99 deletions

View File

@@ -39,9 +39,11 @@ class MCPRouterGroup(group.RouterGroup):
data = await quart.request.json
data = data['source']
uuid = await self.ap.mcp_service.create_mcp_server(data)
return self.success(data={'uuid': uuid})
try:
uuid = await self.ap.mcp_service.create_mcp_server(data)
return self.success(data={'uuid': uuid})
except Exception as e:
return self.http_status(500, -1, f'Failed to create MCP server: {str(e)}')
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
@@ -56,12 +58,18 @@ class MCPRouterGroup(group.RouterGroup):
elif quart.request.method == 'PUT':
data = await quart.request.json
await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data)
return self.success()
try:
await self.ap.mcp_service.update_mcp_server(server_data['uuid'], data)
return self.success()
except Exception as e:
return self.http_status(500, -1, f'Failed to update MCP server: {str(e)}')
elif quart.request.method == 'DELETE':
await self.ap.mcp_service.delete_mcp_server(server_data['uuid'])
return self.success()
try:
await self.ap.mcp_service.delete_mcp_server(server_data['uuid'])
return self.success()
except Exception as e:
return self.http_status(500, -1, f'Failed to delete MCP server: {str(e)}')
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
@@ -71,49 +79,6 @@ class MCPRouterGroup(group.RouterGroup):
if server_data is None:
return self.http_status(404, -1, 'Server not found')
# TODO 这里移到service去
# # 创建测试任务
# ctx = taskmgr.TaskContext.new()
# wrapper = self.ap.task_mgr.create_user_task(
# self._test_mcp_server(server, ctx),
# kind='mcp-operation',
# name=f'mcp-test-{server_name}',
# label=f'Testing MCP server {server_name}',
# context=ctx,
# )
# return self.success(data={'task_id': wrapper.id})
# async def _test_mcp_server(self, server: persistence_mcp.MCPServer, ctx: taskmgr.TaskContext):
# """测试MCP服务器连接"""
# try:
# ctx.current_action = f'Testing connection to {server.name}'
# # 创建临时会话进行测试
# session = RuntimeMCPSession(server.name, {
# 'name': server.name,
# 'mode': server.mode,
# 'enable': server.enable,
# 'url': server.extra_args.get('url',''),
# 'headers': server.extra_args.get('headers',{}),
# 'timeout': server.extra_args.get('timeout',60),
# },enable=True, ap=self.ap)
# await session.start()
# # 获取工具列表作为测试
# tools_count = len(session.functions)
# tool_name_list = []
# for function in session.functions:
# tool_name_list.append(function.name)
# ctx.current_action = f'Successfully connected. Found {tools_count} tools.'
# # 关闭测试会话
# await session.shutdown()
# return {'status': 'success', 'tools_count': tools_count,'tools_names_lists':tool_name_list}
# except Exception as e:
# print(traceback.format_exc())
# ctx.current_action = f'Connection test failed: {str(e)}'
# raise e
task_id = await self.ap.mcp_service.test_mcp_server(server_data['uuid'])
return self.success(data={'task_id': task_id})

View File

@@ -2,9 +2,107 @@ from __future__ import annotations
import sqlalchemy
import uuid
import traceback
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
from ....core import taskmgr
from ....provider.tools.loaders.mcp import RuntimeMCPSession
class RuntimeMCPServer:
"""Runtime MCP Server representation"""
ap: app.Application
mcp_server_entity: persistence_mcp.MCPServer
session: RuntimeMCPSession | None = None
def __init__(self, ap: app.Application, mcp_server_entity: persistence_mcp.MCPServer):
self.ap = ap
self.mcp_server_entity = mcp_server_entity
self.session = None
async def initialize(self):
"""初始化 MCP Server"""
if not self.mcp_server_entity.enable:
return
# 构建配置字典
mixed_config = {
'name': self.mcp_server_entity.name,
'mode': self.mcp_server_entity.mode,
'enable': self.mcp_server_entity.enable,
**self.mcp_server_entity.extra_args,
}
self.session = RuntimeMCPSession(
self.mcp_server_entity.name,
mixed_config,
self.mcp_server_entity.enable,
self.ap
)
await self.session.initialize()
await self.session.start()
async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext):
"""测试MCP服务器连接"""
try:
task_context.set_current_action(f'Testing connection to {self.mcp_server_entity.name}')
# 创建临时会话进行测试
mixed_config = {
'name': self.mcp_server_entity.name,
'mode': self.mcp_server_entity.mode,
'enable': True, # 测试时强制启用
**self.mcp_server_entity.extra_args,
}
test_session = RuntimeMCPSession(
self.mcp_server_entity.name,
mixed_config,
enable=True,
ap=self.ap
)
await test_session.start()
# 获取工具列表作为测试
tools_count = len(test_session.functions)
tool_name_list = []
for function in test_session.functions:
tool_name_list.append(function.name)
task_context.set_current_action(f'Successfully connected. Found {tools_count} tools.')
# 关闭测试会话
await test_session.shutdown()
return {'status': 'success', 'tools_count': tools_count, 'tools_names_lists': tool_name_list}
except Exception as e:
self.ap.logger.error(f'Connection test failed: {str(e)}\n{traceback.format_exc()}')
task_context.set_current_action(f'Connection test failed: {str(e)}')
raise e
async def test_connection(self) -> str:
"""测试 MCP 服务器连接并返回任务 ID"""
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._test_mcp_server_task(task_context=ctx),
kind='mcp-operation',
name=f'mcp-test-{self.mcp_server_entity.name}',
label=f'Testing MCP server {self.mcp_server_entity.name}',
context=ctx,
)
return wrapper.id
async def dispose(self):
"""清理资源"""
if self.session:
await self.session.shutdown()
class MCPService:
@@ -13,6 +111,61 @@ class MCPService:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
def _convert_server_entity_to_config(
self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer]
) -> dict:
"""将数据库实体转换为 loader 需要的配置字典
Args:
server_entity: 数据库查询返回的服务器实体或 Row 对象
Returns:
包含服务器配置的字典
"""
if isinstance(server_entity, sqlalchemy.Row):
server = persistence_mcp.MCPServer(**server_entity._mapping)
else:
server = server_entity
return {
'name': server.name,
'mode': server.mode,
'enable': server.enable,
'extra_args': server.extra_args,
}
async def initialize(self) -> None:
"""初始化 MCP Service从数据库加载所有 MCP 服务器到运行时"""
self.ap.logger.info('Initializing MCP Service and loading servers from database...')
if not self.ap.tool_mgr or not self.ap.tool_mgr.mcp_tool_loader:
self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization')
return
try:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
loaded_count = 0
failed_count = 0
for server in servers:
try:
# 将数据库实体转换为配置字典后传递给 loader
server_config = self._convert_server_entity_to_config(server)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
loaded_count += 1
self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}')
except Exception as e:
failed_count += 1
server_name = getattr(server, 'name', 'unknown')
self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}')
self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}')
except Exception as e:
self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}')
async def get_mcp_servers(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
@@ -22,11 +175,16 @@ class MCPService:
async def create_mcp_server(self, server_data: dict) -> str:
server_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))
server = await self.get_mcp_server(server_data['uuid'])
# TODO: load runtime mcp server session
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid'])
)
server_entity = result.first()
if server_entity and self.ap.tool_mgr.mcp_tool_loader:
server_config = self._convert_server_entity_to_config(server_entity)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
return server['uuid']
return server_data['uuid']
async def get_mcp_server(self, server_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
@@ -47,17 +205,70 @@ class MCPService:
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
old_server = result.first()
old_server_name = old_server.name if old_server else None
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_mcp.MCPServer)
.where(persistence_mcp.MCPServer.uuid == server_uuid)
.values(server_data)
)
# TODO: reload runtime mcp server session
if self.ap.tool_mgr.mcp_tool_loader:
if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name)
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
updated_server = result.first()
if updated_server:
# convert entity to config dict
server_config = self._convert_server_entity_to_config(updated_server)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
async def delete_mcp_server(self, server_uuid: str) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
server_name = server.name if server else None
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
# TODO: remove runtime mcp server session
if server_name and self.ap.tool_mgr.mcp_tool_loader:
if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name)
async def test_mcp_server(self, server_uuid: str) -> str:
"""测试 MCP 服务器连接并返回任务 ID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
if server is None:
raise ValueError(f'Server not found: {server_uuid}')
if isinstance(server, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server._mapping)
else:
server_entity = server
runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity)
return await runtime_server.test_connection()

View File

@@ -129,6 +129,7 @@ class BuildAppStage(stage.BootingStage):
mcp_service_inst = mcp_service.MCPService(ap)
ap.mcp_service = mcp_service_inst
await mcp_service_inst.initialize()
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -10,9 +10,7 @@ from mcp.client.sse import sse_client
from .. import loader
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import sqlalchemy
class RuntimeMCPSession:
@@ -113,8 +111,14 @@ class RuntimeMCPSession:
)
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
"""关闭会话并清理资源"""
try:
if self.exit_stack:
await self.exit_stack.aclose()
self.functions.clear()
self.session = None
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
@loader.loader_class('mcp')
@@ -134,46 +138,48 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
await self.load_mcp_servers_from_db()
pass
async def load_mcp_servers_from_db(self):
self.ap.logger.info('Loading MCP servers from db...')
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
for server in servers:
try:
await self.load_mcp_server(server)
except Exception as e:
self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}')
async def init_runtime_mcp_session(self, server_config: dict):
"""从服务器配置创建运行时会话
async def init_runtime_mcp_session(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
if isinstance(server_entity, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server_entity._mapping)
elif isinstance(server_entity, dict):
server_entity = persistence_mcp.MCPServer(**server_entity)
Args:
server_config: 服务器配置字典,必须包含:
- name: 服务器名称
- mode: 连接模式 (stdio/sse)
- enable: 是否启用
- extra_args: 额外的配置参数 (可选)
"""
name = server_config['name']
mode = server_config['mode']
enable = server_config['enable']
extra_args = server_config.get('extra_args', {})
mixed_config = {
'name': server_entity.name,
'mode': server_entity.mode,
'enable': server_entity.enable,
**server_entity.extra_args,
'name': name,
'mode': mode,
'enable': enable,
**extra_args,
}
session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap)
session = RuntimeMCPSession(name, mixed_config, enable, self.ap)
await session.initialize()
return session
async def load_mcp_server(
self,
server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict,
):
session = await self.init_runtime_mcp_session(server_entity)
async def load_mcp_server(self, server_config: dict):
"""加载 MCP 服务器到运行时
Args:
server_config: 服务器配置字典,必须包含:
- name: 服务器名称
- mode: 连接模式 (stdio/sse)
- enable: 是否启用
- extra_args: 额外的配置参数 (可选)
"""
session = await self.init_runtime_mcp_session(server_config)
await session.start()
self.sessions[server_entity.name] = session
self.sessions[server_config['name']] = session
async def get_tools(self) -> list[resource_tool.LLMTool]:
all_functions = []
@@ -186,24 +192,91 @@ class MCPLoader(loader.ToolLoader):
return all_functions
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
"""检查工具是否存在"""
for session in self.sessions.values():
for function in session.functions:
if function.name == name:
return await function.func(**parameters)
return True
return False
async def invoke_tool(self, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
for session in self.sessions.values():
for function in session.functions:
if function.name == name:
self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
try:
result = await function.func(**parameters)
self.ap.logger.debug(f'MCP tool {name} executed successfully')
return result
except Exception as e:
self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}')
raise
raise ValueError(f'Tool not found: {name}')
async def reload_mcp_server(self, server_config: dict):
"""重新加载 MCP 服务器(先移除再加载)
Args:
server_config: 服务器配置字典,必须包含 name 字段
"""
server_name = server_config['name']
if server_name in self.sessions:
await self.remove_mcp_server(server_name)
# 重新加载
await self.load_mcp_server(server_config)
async def remove_mcp_server(self, server_name: str):
"""移除 MCP 服务器"""
if server_name not in self.sessions:
raise ValueError(f'MCP server {server_name} not found')
self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal')
return
session = self.sessions.pop(server_name)
await session.shutdown()
self.ap.logger.info(f'Removed MCP server: {server_name}')
def get_session(self, server_name: str) -> RuntimeMCPSession | None:
"""获取指定名称的 MCP 会话"""
return self.sessions.get(server_name)
def has_session(self, server_name: str) -> bool:
"""检查是否存在指定名称的 MCP 会话"""
return server_name in self.sessions
def get_all_server_names(self) -> list[str]:
"""获取所有已加载的 MCP 服务器名称"""
return list(self.sessions.keys())
def get_server_tool_count(self, server_name: str) -> int:
"""获取指定服务器的工具数量"""
session = self.get_session(server_name)
return len(session.functions) if session else 0
def get_all_servers_info(self) -> dict[str, dict]:
"""获取所有服务器的信息"""
info = {}
for server_name, session in self.sessions.items():
info[server_name] = {
'name': server_name,
'mode': session.server_config.get('mode'),
'enable': session.enable,
'tools_count': len(session.functions),
'tool_names': [f.name for f in session.functions],
}
return info
async def shutdown(self):
"""关闭工具"""
for session in self.sessions.values():
await session.shutdown()
"""关闭所有工具"""
self.ap.logger.info('Shutting down all MCP sessions...')
for server_name, session in list(self.sessions.items()):
try:
await session.shutdown()
self.ap.logger.debug(f'Shutdown MCP session: {server_name}')
except Exception as e:
self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}')
self.sessions.clear()
self.ap.logger.info('All MCP sessions shutdown complete')