diff --git a/pkg/api/http/controller/groups/resources/mcp.py b/pkg/api/http/controller/groups/resources/mcp.py new file mode 100644 index 00000000..f444639c --- /dev/null +++ b/pkg/api/http/controller/groups/resources/mcp.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +import quart +import asyncio + +from ......core import taskmgr +from ... import group + + +@group.group_class('mcp', '/api/v1/mcp') +class MCPRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('/servers', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + """获取MCP服务器列表""" + if quart.request.method == 'GET': + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.success(data={'servers': []}) + + servers = self.ap.provider_cfg.data.get('mcp', {}).get('servers', []) + + # 获取每个服务器的状态和工具信息 + mcp_loader = None + for loader_name, loader in self.ap.tool_mgr.loaders.items(): + if loader_name == 'mcp': + mcp_loader = loader + break + + servers_with_status = [] + for server in servers: + server_info = { + 'name': server['name'], + 'mode': server['mode'], + 'enable': server['enable'], + 'config': server, + 'status': 'disconnected', + 'tools': [], + 'error': None, + } + + # 检查服务器连接状态 + if mcp_loader and server['name'] in mcp_loader.sessions: + session = mcp_loader.sessions[server['name']] + server_info['status'] = 'connected' + server_info['tools'] = [ + {'name': func.name, 'description': func.description, 'parameters': func.parameters} + for func in session.functions + ] + elif server['enable']: + server_info['status'] = 'error' + server_info['error'] = 'Failed to connect' + + servers_with_status.append(server_info) + + return self.success(data={'servers': servers_with_status}) + elif quart.request.method == 'POST': + data = await quart.request.json + + # 验证必填字段 + required_fields = ['name', 'mode'] + for field in required_fields: + if field not in data: + return self.http_status(400, -1, f'Missing required field: {field}') + + # 检查provider_cfg是否可用 + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.http_status(500, -1, 'Provider configuration not available') + + # 获取当前配置 + mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) + servers = mcp_config['servers'] + + # 检查服务器名称是否重复 + for server in servers: + if server['name'] == data['name']: + return self.http_status(400, -1, 'Server name already exists') + + # 创建新服务器配置 + new_server = { + 'name': data['name'], + 'mode': data['mode'], + 'enable': data.get('enable', True), + } + + # 根据模式添加配置 + if data['mode'] == 'stdio': + new_server.update( + {'command': data.get('command', ''), 'args': data.get('args', []), 'env': data.get('env', {})} + ) + elif data['mode'] == 'sse': + new_server.update( + { + 'url': data.get('url', ''), + 'headers': data.get('headers', {}), + 'timeout': data.get('timeout', 10), + } + ) + + # 添加到配置 + servers.append(new_server) + self.ap.provider_cfg.data['mcp'] = mcp_config + + # 保存配置 + await self.ap.provider_cfg.dump_config() + + # 如果启用,尝试重新加载MCP loader + if new_server['enable']: + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._reload_mcp_loader(ctx), + kind='mcp-operation', + name=f'mcp-reload-{new_server["name"]}', + label=f'Reloading MCP loader for {new_server["name"]}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + else: + return self.success() + else: + return self.http_status(405, -1, 'Method not allowed') + + @self.route('/servers/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN) + async def _(server_name: str) -> str: + """获取、更新或删除MCP服务器配置""" + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.http_status(500, -1, 'Provider configuration not available') + + mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) + servers = mcp_config['servers'] + + # 查找服务器 + server_index = None + for i, server in enumerate(servers): + if server['name'] == server_name: + server_index = i + break + + if server_index is None: + return self.http_status(404, -1, 'Server not found') + + if quart.request.method == 'GET': + return self.success(data={'server': servers[server_index]}) + + elif quart.request.method == 'PUT': + data = await quart.request.json + server = servers[server_index] + + # 更新配置 + server.update( + { + 'enable': data.get('enable', server.get('enable', True)), + } + ) + + # 根据模式更新特定配置 + if server['mode'] == 'stdio': + server.update( + { + 'command': data.get('command', server.get('command', '')), + 'args': data.get('args', server.get('args', [])), + 'env': data.get('env', server.get('env', {})), + } + ) + elif server['mode'] == 'sse': + server.update( + { + 'url': data.get('url', server.get('url', '')), + 'headers': data.get('headers', server.get('headers', {})), + 'timeout': data.get('timeout', server.get('timeout', 10)), + } + ) + + # 保存配置 + await self.ap.provider_cfg.dump_config() + + # 重新加载MCP loader + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._reload_mcp_loader(ctx), + kind='mcp-operation', + name=f'mcp-reload-{server_name}', + label=f'Reloading MCP loader for {server_name}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + + elif quart.request.method == 'DELETE': + # 删除服务器 + servers.pop(server_index) + self.ap.provider_cfg.data['mcp'] = mcp_config + + # 保存配置 + await self.ap.provider_cfg.dump_config() + + # 重新加载MCP loader + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._reload_mcp_loader(ctx), + kind='mcp-operation', + name=f'mcp-remove-{server_name}', + label=f'Removing MCP server {server_name}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + + @self.route('/servers//toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN) + async def _(server_name: str) -> str: + """切换MCP服务器启用状态""" + data = await quart.request.json + target_enabled = data.get('target_enabled') + + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.http_status(500, -1, 'Provider configuration not available') + + mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) + servers = mcp_config['servers'] + + # 查找并更新服务器 + for server in servers: + if server['name'] == server_name: + server['enable'] = target_enabled + break + else: + return self.http_status(404, -1, 'Server not found') + + # 保存配置 + await self.ap.provider_cfg.dump_config() + + # 重新加载MCP loader + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._reload_mcp_loader(ctx), + kind='mcp-operation', + name=f'mcp-toggle-{server_name}', + label=f'Toggling MCP server {server_name}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + + @self.route('/servers//test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _(server_name: str) -> str: + """测试MCP服务器连接""" + if not self.ap or not self.ap.provider_cfg or not self.ap.provider_cfg.data: + return self.http_status(500, -1, 'Provider configuration not available') + + mcp_config = self.ap.provider_cfg.data.get('mcp', {'servers': []}) + servers = mcp_config['servers'] + + # 查找服务器配置 + server_config = None + for server in servers: + if server['name'] == server_name: + server_config = server + break + + if server_config is None: + return self.http_status(404, -1, 'Server not found') + + # 创建测试任务 + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._test_mcp_server(server_config, ctx), + kind='mcp-operation', + name=f'mcp-test-{server_name}', + label=f'Testing MCP server {server_name}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + + @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> str: + """从GitHub安装MCP服务器""" + data = await quart.request.json + source = data.get('source') + + if not source: + return self.http_status(400, -1, 'Missing source parameter') + + # 创建安装任务 + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._install_mcp_from_github(source, ctx), + kind='mcp-operation', + name='install-mcp-github', + label=f'Installing MCP from GitHub: {source}', + context=ctx, + ) + return self.success(data={'task_id': wrapper.id}) + + async def _reload_mcp_loader(self, ctx: taskmgr.TaskContext): + """重新加载MCP loader""" + try: + ctx.current_action = 'Stopping existing MCP sessions' + # 停止现有的MCP会话 + mcp_loader = None + for loader_name, loader in self.ap.tool_mgr.loaders.items(): + if loader_name == 'mcp': + mcp_loader = loader + break + + if mcp_loader: + await mcp_loader.shutdown() + + ctx.current_action = 'Reloading MCP configuration' + # 重新加载MCP loader + await self.ap.tool_mgr.reload_loader('mcp') + + ctx.current_action = 'MCP loader reloaded successfully' + + except Exception as e: + ctx.current_action = f'Failed to reload MCP loader: {str(e)}' + raise e + + async def _test_mcp_server(self, server_config: dict, ctx: taskmgr.TaskContext): + """测试MCP服务器连接""" + try: + from ......provider.tools.loaders.mcp import RuntimeMCPSession + + ctx.current_action = f'Testing connection to {server_config["name"]}' + + # 创建临时会话进行测试 + session = RuntimeMCPSession(server_config['name'], server_config, self.ap) + await session.initialize() + + # 获取工具列表作为测试 + tools_count = len(session.functions) + ctx.current_action = f'Successfully connected. Found {tools_count} tools.' + + # 关闭测试会话 + await session.shutdown() + + return {'status': 'success', 'tools_count': tools_count} + + except Exception as e: + ctx.current_action = f'Connection test failed: {str(e)}' + raise e + + async def _install_mcp_from_github(self, source: str, ctx: taskmgr.TaskContext): + """从GitHub安装MCP服务器的实现""" + try: + ctx.current_action = f'Installing MCP server from {source}' + + # 这里是安装逻辑的占位符 + # 实际实现将包括克隆仓库、解析配置、安装依赖等步骤 + + # 模拟安装过程 + + await asyncio.sleep(2) # 模拟安装过程 + + # 返回成功结果 + return {'status': 'success', 'message': f'Successfully installed MCP server from {source}'} + + except Exception as e: + ctx.current_action = f'Failed to install MCP server: {str(e)}' + raise e diff --git a/pkg/entity/persistence/mcp.py b/pkg/entity/persistence/mcp.py new file mode 100644 index 00000000..74478dc7 --- /dev/null +++ b/pkg/entity/persistence/mcp.py @@ -0,0 +1,20 @@ +import sqlalchemy + +from .base import Base + + +class MCPServer(Base): + __tablename__ = 'mcp_servers' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) + mode = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) # stdio, sse + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index d649b41e..f0bec0a5 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -59,7 +59,7 @@ class ModelManager: try: await self.load_llm_model(llm_model) except provider_errors.RequesterNotFoundError as e: - self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping model {llm_model.uuid}') + self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}') except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') @@ -67,7 +67,14 @@ class ModelManager: result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) embedding_models = result.all() for embedding_model in embedding_models: - await self.load_embedding_model(embedding_model) + try: + await self.load_embedding_model(embedding_model) + except provider_errors.RequesterNotFoundError as e: + self.ap.logger.warning( + f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}' + ) + except Exception as e: + self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}') async def init_runtime_llm_model( self, @@ -107,6 +114,9 @@ class ModelManager: elif isinstance(model_info, dict): model_info = persistence_model.EmbeddingModel(**model_info) + if model_info.requester not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(model_info.requester) + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) await requester_inst.initialize() diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 36fa9751..8677f41c 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -2,6 +2,7 @@ from __future__ import annotations import typing from contextlib import AsyncExitStack +import traceback from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -9,7 +10,9 @@ from mcp.client.sse import sse_client from .. import loader from ....core import app +from ....entity.persistence import mcp as persistence_mcp import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import sqlalchemy class RuntimeMCPSession: @@ -27,11 +30,13 @@ class RuntimeMCPSession: functions: list[resource_tool.LLMTool] = [] - def __init__(self, server_name: str, server_config: dict, ap: app.Application): + enable: bool + + def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application): self.server_name = server_name self.server_config = server_config self.ap = ap - + self.enable = enable self.session = None self.exit_stack = AsyncExitStack() @@ -68,6 +73,12 @@ class RuntimeMCPSession: await self.session.initialize() async def initialize(self): + pass + + async def start(self): + if not self.enable: + return + self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}') if self.server_config['mode'] == 'stdio': @@ -123,13 +134,45 @@ class MCPLoader(loader.ToolLoader): self._last_listed_functions = [] async def initialize(self): - for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []): - if not server_config['enable']: - continue - session = RuntimeMCPSession(server_config['name'], server_config, self.ap) - await session.initialize() - # self.ap.event_loop.create_task(session.initialize()) - self.sessions[server_config['name']] = session + await self.load_mcp_servers_from_db() + + async def load_mcp_servers_from_db(self): + self.ap.logger.info('Loading MCP servers from db...') + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer)) + servers = result.all() + for server in servers: + try: + await self.load_mcp_server(server) + except Exception as e: + self.ap.logger.error(f'Failed to load MCP server {server.name}: {e}\n{traceback.format_exc()}') + + async def init_runtime_mcp_session( + self, + server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict, + ): + if isinstance(server_entity, sqlalchemy.Row): + server_entity = persistence_mcp.MCPServer(**server_entity._mapping) + elif isinstance(server_entity, dict): + server_entity = persistence_mcp.MCPServer(**server_entity) + + mixed_config = { + 'name': server_entity.name, + 'mode': server_entity.mode, + 'enable': server_entity.enable, + **server_entity.extra_args, + } + + session = RuntimeMCPSession(server_entity.name, mixed_config, server_entity.enable, self.ap) + await session.initialize() + + return session + + async def load_mcp_server( + self, + server_entity: persistence_mcp.MCPServer | sqlalchemy.Row[persistence_mcp.MCPServer] | dict, + ): + session = await self.init_runtime_mcp_session(server_entity) + self.sessions[server_entity.name] = session async def get_tools(self) -> list[resource_tool.LLMTool]: all_functions = [] @@ -150,7 +193,14 @@ class MCPLoader(loader.ToolLoader): if function.name == name: return await function.func(**parameters) - raise ValueError(f'未找到工具: {name}') + raise ValueError(f'Tool not found: {name}') + + async def remove_mcp_server(self, server_name: str): + if server_name not in self.sessions: + raise ValueError(f'MCP server {server_name} not found') + + session = self.sessions.pop(server_name) + await session.shutdown() async def shutdown(self): """关闭工具"""