perf: make startup async

This commit is contained in:
Junyan Qin
2025-11-03 20:16:45 +08:00
parent 4d0a28a1a7
commit f3199dda20
3 changed files with 65 additions and 108 deletions

View File

@@ -38,10 +38,7 @@ class RuntimeMCPServer:
}
self.session = RuntimeMCPSession(
self.mcp_server_entity.name,
mixed_config,
self.mcp_server_entity.enable,
self.ap
self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap
)
await self.session.initialize()
await self.session.start()
@@ -59,12 +56,7 @@ class RuntimeMCPServer:
**self.mcp_server_entity.extra_args,
}
test_session = RuntimeMCPSession(
self.mcp_server_entity.name,
mixed_config,
enable=True,
ap=self.ap
)
test_session = RuntimeMCPSession(self.mcp_server_entity.name, mixed_config, enable=True, ap=self.ap)
await test_session.start()
# 获取工具列表作为测试
@@ -104,68 +96,12 @@ class RuntimeMCPServer:
await self.session.shutdown()
class MCPService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
def _convert_server_entity_to_config(
self, server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer]
) -> dict:
"""将数据库实体转换为 loader 需要的配置字典
Args:
server_entity: 数据库查询返回的服务器实体或 Row 对象
Returns:
包含服务器配置的字典
"""
if isinstance(server_entity, sqlalchemy.Row):
server = persistence_mcp.MCPServer(**server_entity._mapping)
else:
server = server_entity
return {
'name': server.name,
'mode': server.mode,
'enable': server.enable,
'extra_args': server.extra_args,
}
async def initialize(self) -> None:
"""初始化 MCP Service从数据库加载所有 MCP 服务器到运行时"""
self.ap.logger.info('Initializing MCP Service and loading servers from database...')
if not self.ap.tool_mgr or not self.ap.tool_mgr.mcp_tool_loader:
self.ap.logger.warning('MCP tool loader not available, skipping MCP servers initialization')
return
try:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
servers = result.all()
loaded_count = 0
failed_count = 0
for server in servers:
try:
# 将数据库实体转换为配置字典后传递给 loader
server_config = self._convert_server_entity_to_config(server)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
loaded_count += 1
self.ap.logger.debug(f'Loaded MCP server: {server_config["name"]}')
except Exception as e:
failed_count += 1
server_name = getattr(server, 'name', 'unknown')
self.ap.logger.error(f'Failed to load MCP server {server_name}: {e}\n{traceback.format_exc()}')
self.ap.logger.info(f'MCP Service initialization complete. Loaded: {loaded_count}, Failed: {failed_count}')
except Exception as e:
self.ap.logger.error(f'Failed to initialize MCP Service: {e}\n{traceback.format_exc()}')
async def get_mcp_servers(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
@@ -180,9 +116,10 @@ class MCPService:
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_data['uuid'])
)
server_entity = result.first()
if server_entity and self.ap.tool_mgr.mcp_tool_loader:
server_config = self._convert_server_entity_to_config(server_entity)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
if server_entity:
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server_entity)
if self.ap.tool_mgr.mcp_tool_loader:
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
return server_data['uuid']
@@ -205,14 +142,12 @@ class MCPService:
return self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
async def update_mcp_server(self, server_uuid: str, server_data: dict) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
old_server = result.first()
old_server_name = old_server.name if old_server else None
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_mcp.MCPServer)
.where(persistence_mcp.MCPServer.uuid == server_uuid)
@@ -220,41 +155,36 @@ class MCPService:
)
if self.ap.tool_mgr.mcp_tool_loader:
if old_server_name and old_server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(old_server_name)
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
updated_server = result.first()
if updated_server:
# convert entity to config dict
server_config = self._convert_server_entity_to_config(updated_server)
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server)
await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config)
async def delete_mcp_server(self, server_uuid: str) -> None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
server_name = server.name if server else None
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
if server_name and self.ap.tool_mgr.mcp_tool_loader:
if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name)
async def test_mcp_server(self, server_uuid: str) -> str:
"""测试 MCP 服务器连接并返回任务 ID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
@@ -262,7 +192,6 @@ class MCPService:
if server is None:
raise ValueError(f'Server not found: {server_uuid}')
if isinstance(server, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server._mapping)
else:
@@ -270,5 +199,4 @@ class MCPService:
runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity)
return await runtime_server.test_connection()