diff --git a/pkg/api/http/controller/groups/resources/mcp.py b/pkg/api/http/controller/groups/resources/mcp.py index 46559bc9..ac91abff 100644 --- a/pkg/api/http/controller/groups/resources/mcp.py +++ b/pkg/api/http/controller/groups/resources/mcp.py @@ -57,10 +57,6 @@ class MCPRouterGroup(group.RouterGroup): @self.route('/servers//test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _(server_name: str) -> str: """测试MCP服务器连接""" - - server_data = await self.ap.mcp_service.get_mcp_server_by_name(server_name) - if server_data is None: - return self.http_status(404, -1, 'Server not found') - - task_id = await self.ap.mcp_service.test_mcp_server(server_data['uuid']) + server_data = await quart.request.json + task_id = await self.ap.mcp_service.test_mcp_server(server_name=server_name, server_data=server_data) return self.success(data={'task_id': task_id}) diff --git a/pkg/api/http/service/mcp.py b/pkg/api/http/service/mcp.py index 4ce6e5c2..3766e7d6 100644 --- a/pkg/api/http/service/mcp.py +++ b/pkg/api/http/service/mcp.py @@ -2,98 +2,12 @@ from __future__ import annotations import sqlalchemy import uuid -import traceback import asyncio from ....core import app from ....entity.persistence import mcp as persistence_mcp from ....core import taskmgr -from ....provider.tools.loaders.mcp import RuntimeMCPSession - - -class RuntimeMCPServer: - """Runtime MCP Server representation""" - - ap: app.Application - - mcp_server_entity: persistence_mcp.MCPServer - - session: RuntimeMCPSession | None = None - - def __init__(self, ap: app.Application, mcp_server_entity: persistence_mcp.MCPServer): - self.ap = ap - self.mcp_server_entity = mcp_server_entity - self.session = None - - async def initialize(self): - """初始化 MCP Server""" - if not self.mcp_server_entity.enable: - return - - # 构建配置字典 - mixed_config = { - 'name': self.mcp_server_entity.name, - 'mode': self.mcp_server_entity.mode, - 'enable': self.mcp_server_entity.enable, - **self.mcp_server_entity.extra_args, - } - - self.session = RuntimeMCPSession( - self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap - ) - await self.session.start() - - async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext): - """测试MCP服务器连接""" - try: - task_context.set_current_action(f'Testing connection to {self.mcp_server_entity.name}') - - # 创建临时会话进行测试 - mixed_config = { - 'name': self.mcp_server_entity.name, - 'mode': self.mcp_server_entity.mode, - 'enable': True, # 测试时强制启用 - **self.mcp_server_entity.extra_args, - } - - test_session = RuntimeMCPSession(self.mcp_server_entity.name, mixed_config, enable=True, ap=self.ap) - await test_session.start() - - # 获取工具列表作为测试 - tools_count = len(test_session.functions) - - tool_name_list = [] - for function in test_session.functions: - tool_name_list.append(function.name) - - task_context.set_current_action(f'Successfully connected. Found {tools_count} tools.') - - # 关闭测试会话 - await test_session.shutdown() - - return {'status': 'success', 'tools_count': tools_count, 'tools_names_lists': tool_name_list} - - except Exception as e: - self.ap.logger.error(f'Connection test failed: {str(e)}\n{traceback.format_exc()}') - task_context.set_current_action(f'Connection test failed: {str(e)}') - raise e - - async def test_connection(self) -> str: - """测试 MCP 服务器连接并返回任务 ID""" - ctx = taskmgr.TaskContext.new() - wrapper = self.ap.task_mgr.create_user_task( - self._test_mcp_server_task(task_context=ctx), - kind='mcp-operation', - name=f'mcp-test-{self.mcp_server_entity.name}', - label=f'Testing MCP server {self.mcp_server_entity.name}', - context=ctx, - ) - return wrapper.id - - async def dispose(self): - """清理资源""" - if self.session: - await self.session.shutdown() +from ....provider.tools.loaders.mcp import RuntimeMCPSession, MCPSessionStatus class MCPService: @@ -176,7 +90,6 @@ class MCPService: if updated_server: # convert entity to config dict server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, updated_server) - # await self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config) task = asyncio.create_task(self.ap.tool_mgr.mcp_tool_loader.host_mcp_server(server_config)) self.ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks.append(task) @@ -195,21 +108,30 @@ class MCPService: if server_name in self.ap.tool_mgr.mcp_tool_loader.sessions: await self.ap.tool_mgr.mcp_tool_loader.remove_mcp_server(server_name) - async def test_mcp_server(self, server_uuid: str) -> str: + async def test_mcp_server(self, server_name: str, server_data: dict) -> int: """测试 MCP 服务器连接并返回任务 ID""" - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_mcp.MCPServer).where(persistence_mcp.MCPServer.uuid == server_uuid) - ) - server = result.first() - if server is None: - raise ValueError(f'Server not found: {server_uuid}') + runtime_mcp_session: RuntimeMCPSession | None = None - if isinstance(server, sqlalchemy.Row): - server_entity = persistence_mcp.MCPServer(**server._mapping) + if server_name != '_': + runtime_mcp_session = self.ap.tool_mgr.mcp_tool_loader.get_session(server_name) + if runtime_mcp_session is None: + raise ValueError(f'Server not found: {server_name}') + + if runtime_mcp_session.status == MCPSessionStatus.ERROR: + coroutine = runtime_mcp_session.start() + else: + coroutine = runtime_mcp_session.refresh() else: - server_entity = server + runtime_mcp_session = await self.ap.tool_mgr.mcp_tool_loader.load_mcp_server(server_config=server_data) + coroutine = runtime_mcp_session.start() - runtime_server = RuntimeMCPServer(ap=self.ap, mcp_server_entity=server_entity) - - return await runtime_server.test_connection() + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + coroutine, + kind='mcp-operation', + name=f'mcp-test-{server_name}', + label=f'Testing MCP server {server_name}', + context=ctx, + ) + return wrapper.id diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 531f75f6..edff9e01 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -101,29 +101,7 @@ class RuntimeMCPSession: else: raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}') - tools = await self.session.list_tools() - - self.ap.logger.debug(f'获取 MCP 工具: {tools}') - - for tool in tools.tools: - - async def func(*, _tool=tool, **kwargs): - result = await self.session.call_tool(_tool.name, kwargs) - if result.isError: - raise Exception(result.content[0].text) - return result.content[0].text - - func.__name__ = tool.name - - self.functions.append( - resource_tool.LLMTool( - name=tool.name, - human_desc=tool.description, - description=tool.description, - parameters=tool.inputSchema, - func=func, - ) - ) + await self.refresh() self.status = MCPSessionStatus.CONNECTED self.last_test_error_message = '' @@ -132,6 +110,33 @@ class RuntimeMCPSession: self.last_test_error_message = str(e) raise e + async def refresh(self): + self.functions.clear() + + tools = await self.session.list_tools() + + self.ap.logger.debug(f'Refresh MCP tools: {tools}') + + for tool in tools.tools: + + async def func(*, _tool=tool, **kwargs): + result = await self.session.call_tool(_tool.name, kwargs) + if result.isError: + raise Exception(result.content[0].text) + return result.content[0].text + + func.__name__ = tool.name + + self.functions.append( + resource_tool.LLMTool( + name=tool.name, + human_desc=tool.description, + description=tool.description, + parameters=tool.inputSchema, + func=func, + ) + ) + def get_tools(self) -> list[resource_tool.LLMTool]: return self.functions diff --git a/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx b/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx index 525e0081..fd19cd4b 100644 --- a/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx +++ b/web/src/app/home/plugins/mcp-server/mcp-card/MCPCardComponent.tsx @@ -51,7 +51,7 @@ export default function MCPCardComponent({ setTesting(true); httpClient - .testMCPServer(cardVO.name) + .testMCPServer(cardVO.name, {}) .then((resp) => { const taskId = resp.task_id; @@ -62,9 +62,11 @@ export default function MCPCardComponent({ setTesting(false); if (taskResp.runtime.exception) { - toast.error(t('mcp.testFailed') + taskResp.runtime.exception); + toast.error( + t('mcp.refreshFailed') + taskResp.runtime.exception, + ); } else { - toast.success(t('mcp.testSuccess')); + toast.success(t('mcp.refreshSuccess')); } // Refresh to get updated runtime_info @@ -74,7 +76,7 @@ export default function MCPCardComponent({ }, 1000); }) .catch((err) => { - toast.error(t('mcp.testFailed') + err.message); + toast.error(t('mcp.refreshFailed') + err.message); setTesting(false); }); } diff --git a/web/src/app/home/plugins/mcp-server/mcp-form/MCPFormDialog.tsx b/web/src/app/home/plugins/mcp-server/mcp-form/MCPFormDialog.tsx index 717a24a8..4638bd1e 100644 --- a/web/src/app/home/plugins/mcp-server/mcp-form/MCPFormDialog.tsx +++ b/web/src/app/home/plugins/mcp-server/mcp-form/MCPFormDialog.tsx @@ -361,11 +361,22 @@ export default function MCPFormDialog({ } async function testMcp() { - const serverName = form.getValues('name'); setMcpTesting(true); try { - const { task_id } = await httpClient.testMCPServer(serverName); + const { task_id } = await httpClient.testMCPServer('_', { + name: form.getValues('name'), + mode: 'sse', + enable: true, + extra_args: { + url: form.getValues('url'), + timeout: form.getValues('timeout'), + ssereadtimeout: form.getValues('ssereadtimeout'), + headers: Object.fromEntries( + extraArgs.map((arg) => [arg.key, arg.value]), + ), + }, + }); if (!task_id) { throw new Error(t('mcp.noTaskId')); } @@ -388,13 +399,11 @@ export default function MCPFormDialog({ tool_count: 0, tools: [], }); - } else if (taskResp.runtime.result) { - await loadServerForEdit(serverName); - toast.success(t('mcp.testSuccess')); } else { - toast.error( - `${t('mcp.testError')}: ${t('mcp.noResultReturned')}`, - ); + if (isEditMode) { + await loadServerForEdit(form.getValues('name')); + } + toast.success(t('mcp.testSuccess')); } } } catch (err) { diff --git a/web/src/app/infra/http/BackendClient.ts b/web/src/app/infra/http/BackendClient.ts index 3d0c3ee3..cc47d3fa 100644 --- a/web/src/app/infra/http/BackendClient.ts +++ b/web/src/app/infra/http/BackendClient.ts @@ -524,8 +524,11 @@ export class BackendClient extends BaseHttpClient { }); } - public testMCPServer(serverName: string): Promise { - return this.post(`/api/v1/mcp/servers/${serverName}/test`); + public testMCPServer( + serverName: string, + serverData: object, + ): Promise { + return this.post(`/api/v1/mcp/servers/${serverName}/test`, serverData); } public installMCPServerFromGithub( diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 71b70cdf..9b76fd86 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -312,9 +312,11 @@ const enUS = { value: 'Value', testing: 'Testing...', connecting: 'Connecting...', - testSuccess: 'Connection test successful', - testFailed: 'Connection test failed: ', - testError: 'Connection test error', + testSuccess: 'Test successful', + testFailed: 'Test failed: ', + testError: 'Test error', + refreshSuccess: 'Refresh successful', + refreshFailed: 'Refresh failed: ', connectionSuccess: 'Connection successful', connectionFailed: 'Connection failed', toolsFound: 'tools', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 904d9de8..487aa202 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -314,9 +314,11 @@ const jaJP = { value: '値', testing: 'テスト中...', connecting: '接続中...', - testSuccess: '接続テストに成功しました', - testFailed: '接続テストに失敗しました:', - testError: '接続テストエラー', + testSuccess: '刷新に成功しました', + testFailed: '刷新に失敗しました:', + testError: '刷新エラー', + refreshSuccess: '刷新に成功しました', + refreshFailed: '刷新に失敗しました:', connectionSuccess: '接続に成功しました', connectionFailed: '接続に失敗しました', toolsFound: '個のツール', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index d40bdfb9..20fd1f2b 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -300,9 +300,11 @@ const zhHans = { value: '值', testing: '测试中...', connecting: '连接中...', - testSuccess: '连接测试成功', - testFailed: '连接测试失败:', - testError: '连接测试出错', + testSuccess: '测试成功', + testFailed: '测试失败:', + testError: '刷新出错', + refreshSuccess: '刷新成功', + refreshFailed: '刷新失败:', connectionSuccess: '连接成功', connectionFailed: '连接失败', toolsFound: '个工具', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 581cdb6c..478433b2 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -298,9 +298,11 @@ const zhHant = { value: '值', testing: '測試中...', connecting: '連接中...', - testSuccess: '連接測試成功', - testFailed: '連接測試失敗:', - testError: '連接測試出錯', + testSuccess: '測試成功', + testFailed: '刷新失敗:', + testError: '刷新出錯', + refreshSuccess: '刷新成功', + refreshFailed: '刷新失敗:', connectionSuccess: '連接成功', connectionFailed: '連接失敗', toolsFound: '個工具',