feat: add supports for loading mcp server as LLM tools provider

This commit is contained in:
Junyan Qin
2025-03-19 12:08:47 +08:00
parent ebe0b2f335
commit 40275c3ef1
9 changed files with 253 additions and 1 deletions

View File

@@ -34,6 +34,7 @@ required_deps = {
"dashscope": "dashscope",
"telegram": "python-telegram-bot",
"certifi": "certifi",
"mcp": "mcp",
}

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("mcp-config", 37)
class MCPConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'mcp' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['mcp'] = {
"servers": []
}
await self.ap.provider_cfg.dump_config()

View File

@@ -12,6 +12,8 @@ from ..migrations import m020_wecom_config, m021_lark_config, m022_lmstudio_conf
from ..migrations import m026_qqofficial_config, m027_wx_official_account_config, m028_aliyun_requester_config
from ..migrations import m029_dashscope_app_api_config, m030_lark_config_cmpl, m031_dingtalk_config, m032_volcark_config
from ..migrations import m033_dify_thinking_config, m034_gewechat_file_url_config, m035_wxoa_mode, m036_wxoa_loading_message
from ..migrations import m037_mcp_config
@stage.stage_class("MigrationStage")
class MigrationStage(stage.BootingStage):

View File

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
import typing
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from .. import loader, entities as tools_entities
from ....core import app, entities as core_entities
class RuntimeMCPSession:
"""运行时 MCP 会话"""
ap: app.Application
server_name: str
server_config: dict
session: ClientSession
exit_stack: AsyncExitStack
functions: list[tools_entities.LLMFunction] = []
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.exit_stack = AsyncExitStack()
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
command=self.server_config["command"],
args=self.server_config["args"],
env=self.server_config["env"],
)
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
stdio, write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(stdio, write)
)
await self.session.initialize()
async def _init_sse_server(self):
sse_transport = await self.exit_stack.enter_async_context(
sse_client(
self.server_config["url"],
headers=self.server_config.get("headers", {}),
timeout=self.server_config.get("timeout", 10),
)
)
sseio, write = sse_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(sseio, write)
)
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")
if self.server_config["mode"] == "stdio":
await self._init_stdio_python_server()
elif self.server_config["mode"] == "sse":
await self._init_sse_server()
else:
raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")
tools = await self.session.list_tools()
self.ap.logger.debug(f"获取 MCP 工具: {tools}")
for tool in tools.tools:
async def func(query: core_entities.Query, **kwargs):
result = await self.session.call_tool(tool.name, kwargs)
if result.isError:
raise Exception(result.content[0].text)
return result.content[0].text
func.__name__ = tool.name
self.functions.append(tools_entities.LLMFunction(
name=tool.name,
human_desc=tool.description,
description=tool.description,
parameters=tool.inputSchema,
func=func,
))
@loader.loader_class("mcp")
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
在此加载器中管理所有与 MCP Server 的连接。
"""
sessions: dict[str, RuntimeMCPSession] = {}
_last_listed_functions: list[tools_entities.LLMFunction] = []
async def initialize(self):
for server_config in self.ap.provider_cfg.data.get("mcp", {}).get("servers", []):
if not server_config["enable"]:
continue
session = RuntimeMCPSession(server_config["name"], server_config, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config["name"]] = session
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
all_functions = []
for server_name, session in self.sessions.items():
all_functions.extend(session.functions)
self._last_listed_functions = all_functions
return all_functions
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:
return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}")

View File

@@ -6,7 +6,7 @@ import traceback
from ...core import app, entities as core_entities
from . import entities, loader as tools_loader
from ...plugin import context as plugin_context
from .loaders import plugin
from .loaders import plugin, mcp
class ToolManager:

View File

@@ -33,6 +33,7 @@ dingtalk_stream
dashscope
python-telegram-bot
certifi
mcp
# indirect
taskgroup==0.0.0a4

View File

@@ -138,5 +138,8 @@
"date": "2023-08-10"
}
}
},
"mcp": {
"servers": []
}
}

View File

@@ -520,6 +520,87 @@
}
}
}
},
"mcp": {
"type": "object",
"title": "MCP 配置",
"properties": {
"servers": {
"type": "array",
"title": "MCP 服务器配置",
"default": [],
"items": {
"type": "object",
"oneOf": [
{
"title": "Stdio 模式服务器",
"properties": {
"mode": {
"type": "string",
"title": "模式",
"const": "stdio"
},
"enable": {
"type": "boolean",
"title": "启用"
},
"name": {
"type": "string",
"title": "名称"
},
"command": {
"type": "string",
"title": "启动命令"
},
"args": {
"type": "array",
"title": "启动参数",
"items": {
"type": "string"
},
"default": []
},
"env": {
"type": "object",
"default": {}
}
}
},
{
"title": "SSE 模式服务器",
"properties": {
"mode": {
"type": "string",
"title": "模式",
"const": "sse"
},
"enable": {
"type": "boolean",
"title": "启用"
},
"name": {
"type": "string",
"title": "名称"
},
"url": {
"type": "string",
"title": "URL"
},
"headers": {
"type": "object",
"default": {}
},
"timeout": {
"type": "number",
"title": "请求超时时间",
"default": 10
}
}
}
]
}
}
}
}
}
}