perf: mcp server testing and refreshing

This commit is contained in:
Junyan Qin
2025-11-04 18:14:59 +08:00
parent 1afecf01e4
commit 1046f3c2aa
10 changed files with 101 additions and 156 deletions

View File

@@ -57,10 +57,6 @@ class MCPRouterGroup(group.RouterGroup):
@self.route('/servers/<server_name>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(server_name: str) -> str:
"""测试MCP服务器连接"""
server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name)
if server_data is None:
return self.http_status(404, -1, 'Server not found')
task_id = await self.ap.mcp_service.test_mcp_server(server_data['uuid'])
server_data = await quart.request.json
task_id = await self.ap.mcp_service.test_mcp_server(server_name=server_name, server_data=server_data)
return self.success(data={'task_id': task_id})

View File

@@ -2,98 +2,12 @@ from __future__ import annotations
import sqlalchemy
import uuid
import traceback
import asyncio
from ....core import app
from ....entity.persistence import mcp as persistence_mcp
from ....core import taskmgr
from ....provider.tools.loaders.mcp import RuntimeMCPSession
class RuntimeMCPServer:
"""Runtime MCP Server representation"""
ap: app.Application
mcp_server_entity: persistence_mcp.MCPServer
session: RuntimeMCPSession | None = None
def __init__(self, ap: app.Application, mcp_server_entity: persistence_mcp.MCPServer):
self.ap = ap
self.mcp_server_entity = mcp_server_entity
self.session = None
async def initialize(self):
"""初始化 MCP Server"""
if not self.mcp_server_entity.enable:
return
# 构建配置字典
mixed_config = {
'name': self.mcp_server_entity.name,
'mode': self.mcp_server_entity.mode,
'enable': self.mcp_server_entity.enable,
**self.mcp_server_entity.extra_args,
}
self.session = RuntimeMCPSession(
self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap
)
await self.session.start()
async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext):
"""测试MCP服务器连接"""
try:
task_context.set_current_action(f'Testing connection to {self.mcp_server_entity.name}')
# 创建临时会话进行测试
mixed_config = {
'name': self.mcp_server_entity.name,
'mode': self.mcp_server_entity.mode,
'enable': True, # 测试时强制启用
**self.mcp_server_entity.extra_args,
}
test_session = RuntimeMCPSession(self.mcp_server_entity.name, mixed_config, enable=True, ap=self.ap)
await test_session.start()
# 获取工具列表作为测试
tools_count = len(test_session.functions)
tool_name_list = []
for function in test_session.functions:
tool_name_list.append(function.name)
task_context.set_current_action(f'Successfully connected. Found {tools_count} tools.')
# 关闭测试会话
await test_session.shutdown()
return {'status': 'success', 'tools_count': tools_count, 'tools_names_lists': tool_name_list}
except Exception as e:
self.ap.logger.error(f'Connection test failed: {str(e)}\n{traceback.format_exc()}')
task_context.set_current_action(f'Connection test failed: {str(e)}')
raise e
async def test_connection(self) -> str:
"""测试 MCP 服务器连接并返回任务 ID"""
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._test_mcp_server_task(task_context=ctx),
kind='mcp-operation',
name=f'mcp-test-{self.mcp_server_entity.name}',
label=f'Testing MCP server {self.mcp_server_entity.name}',
context=ctx,
)
return wrapper.id
async def dispose(self):
"""清理资源"""
if self.session:
await self.session.shutdown()
from ....provider.tools.loaders.mcp import RuntimeMCPSession, MCPSessionStatus
class MCPService:
@@ -176,7 +90,6 @@ class MCPService:
if updated_server:
# convert entity to config dict
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server)
# await self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config)
task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config))
self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task)
@@ -195,21 +108,30 @@ class MCPService:
if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions:
await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name)
async def test_mcp_server(self, server_uuid: str) -> str:
async def test_mcp_server(self, server_name: str, server_data: dict) -> int:
"""测试 MCP 服务器连接并返回任务 ID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid)
)
server = result.first()
if server is None:
raise ValueError(f'Server not found: {server_uuid}')
runtime_mcp_session: RuntimeMCPSession | None = None
if isinstance(server, sqlalchemy.Row):
server_entity = persistence_mcp.MCPServer(**server._mapping)
if server_name != '_':
runtime_mcp_session = self.ap.tool_mgr.mcp_tool_loader.get_session(server_name)
if runtime_mcp_session is None:
raise ValueError(f'Server not found: {server_name}')
if runtime_mcp_session.status == MCPSessionStatus.ERROR:
coroutine = runtime_mcp_session.start()
else:
coroutine = runtime_mcp_session.refresh()
else:
server_entity = server
runtime_mcp_session = await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config=server_data)
coroutine = runtime_mcp_session.start()
runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity)
return await runtime_server.test_connection()
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
coroutine,
kind='mcp-operation',
name=f'mcp-test-{server_name}',
label=f'Testing MCP server {server_name}',
context=ctx,
)
return wrapper.id