diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 805b50f7..c1db6482 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -34,6 +34,7 @@ required_deps = { "dashscope": "dashscope", "telegram": "python-telegram-bot", "certifi": "certifi", + "mcp": "mcp", } diff --git a/pkg/core/migrations/m037_mcp_config.py b/pkg/core/migrations/m037_mcp_config.py new file mode 100644 index 00000000..f045f0ff --- /dev/null +++ b/pkg/core/migrations/m037_mcp_config.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("mcp-config", 37) +class MCPConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'mcp' not in self.ap.provider_cfg.data + + async def run(self): + """执行迁移""" + self.ap.provider_cfg.data['mcp'] = { + "servers": [] + } + + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index e528e164..fe0dc464 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -12,6 +12,8 @@ from ..migrations import m020_wecom_config, m021_lark_config, m022_lmstudio_conf from ..migrations import m026_qqofficial_config, m027_wx_official_account_config, m028_aliyun_requester_config from ..migrations import m029_dashscope_app_api_config, m030_lark_config_cmpl, m031_dingtalk_config, m032_volcark_config from ..migrations import m033_dify_thinking_config, m034_gewechat_file_url_config, m035_wxoa_mode, m036_wxoa_loading_message +from ..migrations import m037_mcp_config + @stage.stage_class("MigrationStage") class MigrationStage(stage.BootingStage): diff --git a/pkg/provider/tools/loaders/__init__.py b/pkg/provider/tools/loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py new file mode 100644 index 00000000..4d15bf60 --- /dev/null +++ b/pkg/provider/tools/loaders/mcp.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import typing +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.client.sse import sse_client + +from .. import loader, entities as tools_entities +from ....core import app, entities as core_entities + + +class RuntimeMCPSession: + """运行时 MCP 会话""" + + ap: app.Application + + server_name: str + + server_config: dict + + session: ClientSession + + exit_stack: AsyncExitStack + + functions: list[tools_entities.LLMFunction] = [] + + def __init__(self, server_name: str, server_config: dict, ap: app.Application): + self.server_name = server_name + self.server_config = server_config + self.ap = ap + + self.exit_stack = AsyncExitStack() + + async def _init_stdio_python_server(self): + server_params = StdioServerParameters( + command=self.server_config["command"], + args=self.server_config["args"], + env=self.server_config["env"], + ) + + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params) + ) + + stdio, write = stdio_transport + + self.session = await self.exit_stack.enter_async_context( + ClientSession(stdio, write) + ) + + await self.session.initialize() + + async def _init_sse_server(self): + sse_transport = await self.exit_stack.enter_async_context( + sse_client( + self.server_config["url"], + headers=self.server_config.get("headers", {}), + timeout=self.server_config.get("timeout", 10), + ) + ) + + sseio, write = sse_transport + + self.session = await self.exit_stack.enter_async_context( + ClientSession(sseio, write) + ) + + await self.session.initialize() + + async def initialize(self): + self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}") + + if self.server_config["mode"] == "stdio": + await self._init_stdio_python_server() + elif self.server_config["mode"] == "sse": + await self._init_sse_server() + else: + raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}") + + tools = await self.session.list_tools() + + self.ap.logger.debug(f"获取 MCP 工具: {tools}") + + for tool in tools.tools: + + async def func(query: core_entities.Query, **kwargs): + result = await self.session.call_tool(tool.name, kwargs) + if result.isError: + raise Exception(result.content[0].text) + return result.content[0].text + + func.__name__ = tool.name + + self.functions.append(tools_entities.LLMFunction( + name=tool.name, + human_desc=tool.description, + description=tool.description, + parameters=tool.inputSchema, + func=func, + )) + +@loader.loader_class("mcp") +class MCPLoader(loader.ToolLoader): + """MCP 工具加载器。 + + 在此加载器中管理所有与 MCP Server 的连接。 + """ + + sessions: dict[str, RuntimeMCPSession] = {} + + _last_listed_functions: list[tools_entities.LLMFunction] = [] + + async def initialize(self): + + for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []): + if not server_config["enable"]: + continue + session = RuntimeMCPSession(server_config["name"], server_config, self.ap) + await session.initialize() + # self.ap.event_loop.create_task(session.initialize()) + self.sessions[server_config["name"]] = session + + async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]: + all_functions = [] + + for server_name, session in self.sessions.items(): + all_functions.extend(session.functions) + + self._last_listed_functions = all_functions + + return all_functions + + async def has_tool(self, name: str) -> bool: + return name in [f.name for f in self._last_listed_functions] + + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + for server_name, session in self.sessions.items(): + for function in session.functions: + if function.name == name: + return await function.func(query, **parameters) + + raise ValueError(f"未找到工具: {name}") \ No newline at end of file diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 9986d3ab..1688937d 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -6,7 +6,7 @@ import traceback from ...core import app, entities as core_entities from . import entities, loader as tools_loader from ...plugin import context as plugin_context -from .loaders import plugin +from .loaders import plugin, mcp class ToolManager: diff --git a/requirements.txt b/requirements.txt index a867e055..243d2da7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,7 @@ dingtalk_stream dashscope python-telegram-bot certifi +mcp # indirect taskgroup==0.0.0a4 \ No newline at end of file diff --git a/templates/provider.json b/templates/provider.json index 76aa0218..e34f6d32 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -138,5 +138,8 @@ "date": "2023-08-10" } } + }, + "mcp": { + "servers": [] } } \ No newline at end of file diff --git a/templates/schema/provider.json b/templates/schema/provider.json index 0d2d2f01..def2cc12 100644 --- a/templates/schema/provider.json +++ b/templates/schema/provider.json @@ -520,6 +520,87 @@ } } } + }, + "mcp": { + "type": "object", + "title": "MCP 配置", + "properties": { + "servers": { + "type": "array", + "title": "MCP 服务器配置", + "default": [], + "items": { + "type": "object", + "oneOf": [ + { + "title": "Stdio 模式服务器", + "properties": { + "mode": { + "type": "string", + "title": "模式", + "const": "stdio" + }, + "enable": { + "type": "boolean", + "title": "启用" + }, + "name": { + "type": "string", + "title": "名称" + }, + "command": { + "type": "string", + "title": "启动命令" + }, + "args": { + "type": "array", + "title": "启动参数", + "items": { + "type": "string" + }, + "default": [] + }, + "env": { + "type": "object", + "default": {} + } + } + }, + { + "title": "SSE 模式服务器", + "properties": { + "mode": { + "type": "string", + "title": "模式", + "const": "sse" + }, + "enable": { + "type": "boolean", + "title": "启用" + }, + "name": { + "type": "string", + "title": "名称" + }, + "url": { + "type": "string", + "title": "URL" + }, + "headers": { + "type": "object", + "default": {} + }, + "timeout": { + "type": "number", + "title": "请求超时时间", + "default": 10 + } + } + } + ] + } + } + } } } } \ No newline at end of file