mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-26 23:44:19 +00:00
refactor: mcp server datastructure
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import traceback
|
||||
|
||||
|
||||
from ... import group
|
||||
@@ -13,36 +14,18 @@ class MCPRouterGroup(group.RouterGroup):
|
||||
async def _() -> str:
|
||||
"""获取MCP服务器列表"""
|
||||
if quart.request.method == 'GET':
|
||||
servers = await self.ap.mcp_service.get_mcp_servers()
|
||||
servers = await self.ap.mcp_service.get_mcp_servers(contain_runtime_info=True)
|
||||
|
||||
servers_with_status = []
|
||||
# 获取MCP工具加载器
|
||||
mcp_loader = self.ap.tool_mgr.mcp_tool_loader
|
||||
|
||||
for server in servers:
|
||||
# 从运行中的会话获取工具数量
|
||||
tools_count = 0
|
||||
if mcp_loader:
|
||||
session = mcp_loader.sessions.get(server['name'])
|
||||
if session:
|
||||
tools_count = len(session.functions)
|
||||
|
||||
server_info = {
|
||||
**server,
|
||||
'tools': tools_count,
|
||||
}
|
||||
servers_with_status.append(server_info)
|
||||
|
||||
return self.success(data={'servers': servers_with_status})
|
||||
return self.success(data={'servers': servers})
|
||||
|
||||
elif quart.request.method == 'POST':
|
||||
data = await quart.request.json
|
||||
data = data['source']
|
||||
|
||||
try:
|
||||
uuid = await self.ap.mcp_service.create_mcp_server(data)
|
||||
return self.success(data={'uuid': uuid})
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return self.http_status(500, -1, f'Failed to create MCP server: {str(e)}')
|
||||
|
||||
@self.route('/servers/<server_name>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
||||
|
||||
@@ -40,7 +40,6 @@ class RuntimeMCPServer:
|
||||
self.session = RuntimeMCPSession(
|
||||
self.mcp_server_entity.name, mixed_config, self.mcp_server_entity.enable, self.ap
|
||||
)
|
||||
await self.session.initialize()
|
||||
await self.session.start()
|
||||
|
||||
async def _test_mcp_server_task(self, task_context: taskmgr.TaskContext):
|
||||
@@ -102,14 +101,29 @@ class MCPService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_mcp_servers(self) -> list[dict]:
|
||||
async def get_mcp_servers(self, contain_runtime_info: bool = False) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_mcp.MCPServer))
|
||||
|
||||
servers = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) for server in servers]
|
||||
serialized_servers = [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server) for server in servers
|
||||
]
|
||||
if contain_runtime_info:
|
||||
for server in serialized_servers:
|
||||
session = self.ap.tool_mgr.mcp_tool_loader.get_session(server['name'])
|
||||
|
||||
runtime_info = None
|
||||
|
||||
if session:
|
||||
runtime_info = session.get_runtime_info_dict()
|
||||
|
||||
server['runtime_info'] = runtime_info if runtime_info else None
|
||||
|
||||
return serialized_servers
|
||||
|
||||
async def create_mcp_server(self, server_data: dict) -> str:
|
||||
server_data['uuid'] = str(uuid.uuid4())
|
||||
print('server_data:', server_data)
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_mcp.MCPServer).values(server_data))
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
|
||||
@@ -58,7 +58,7 @@ class PluginRuntimeConnector:
|
||||
|
||||
async def heartbeat_loop(self):
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
await asyncio.sleep(20)
|
||||
try:
|
||||
await self.ping_plugin_runtime()
|
||||
self.ap.logger.debug('Heartbeat to plugin runtime success.')
|
||||
|
||||
@@ -33,6 +33,10 @@ class RuntimeMCPSession:
|
||||
|
||||
enable: bool
|
||||
|
||||
connected: bool
|
||||
|
||||
last_test_error_message: str
|
||||
|
||||
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
|
||||
self.server_name = server_name
|
||||
self.server_config = server_config
|
||||
@@ -43,6 +47,9 @@ class RuntimeMCPSession:
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.functions = []
|
||||
|
||||
self.connected = False
|
||||
self.last_test_error_message = ''
|
||||
|
||||
async def _init_stdio_python_server(self):
|
||||
server_params = StdioServerParameters(
|
||||
command=self.server_config['command'],
|
||||
@@ -64,6 +71,7 @@ class RuntimeMCPSession:
|
||||
self.server_config['url'],
|
||||
headers=self.server_config.get('headers', {}),
|
||||
timeout=self.server_config.get('timeout', 10),
|
||||
sse_read_timeout=self.server_config.get('ssereadtimeout', 30),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -73,47 +81,66 @@ class RuntimeMCPSession:
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def start(self):
|
||||
if not self.enable:
|
||||
return
|
||||
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
else:
|
||||
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
|
||||
try:
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
else:
|
||||
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
|
||||
|
||||
tools = await self.session.list_tools()
|
||||
tools = await self.session.list_tools()
|
||||
|
||||
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
|
||||
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
|
||||
|
||||
for tool in tools.tools:
|
||||
for tool in tools.tools:
|
||||
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
return result.content[0].text
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
return result.content[0].text
|
||||
|
||||
func.__name__ = tool.name
|
||||
func.__name__ = tool.name
|
||||
|
||||
self.functions.append(
|
||||
resource_tool.LLMTool(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
self.functions.append(
|
||||
resource_tool.LLMTool(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.connected = True
|
||||
self.last_test_error_message = ''
|
||||
except Exception as e:
|
||||
self.connected = False
|
||||
self.last_test_error_message = str(e)
|
||||
raise e
|
||||
|
||||
def get_tools(self) -> list[resource_tool.LLMTool]:
|
||||
return self.functions
|
||||
|
||||
def get_runtime_info_dict(self) -> dict:
|
||||
return {
|
||||
'connected': self.connected,
|
||||
'error_message': self.last_test_error_message,
|
||||
'tool_count': len(self.get_tools()),
|
||||
'tools': [
|
||||
{
|
||||
'name': tool.name,
|
||||
'description': tool.description,
|
||||
}
|
||||
for tool in self.get_tools()
|
||||
],
|
||||
}
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭会话并清理资源"""
|
||||
try:
|
||||
@@ -156,9 +183,9 @@ class MCPLoader(loader.ToolLoader):
|
||||
servers = result.all()
|
||||
|
||||
for server in servers:
|
||||
server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, server)
|
||||
|
||||
async def load_mcp_server_task():
|
||||
async def load_mcp_server_task(server_config: dict):
|
||||
self.ap.logger.debug(f'Loading MCP server {server_config}')
|
||||
try:
|
||||
session = await self.load_mcp_server(server_config)
|
||||
@@ -180,7 +207,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
|
||||
self.ap.logger.debug(f'Started MCP server {server_config["name"]}({server_config["uuid"]})')
|
||||
|
||||
task = asyncio.create_task(load_mcp_server_task())
|
||||
task = asyncio.create_task(load_mcp_server_task(config))
|
||||
self._startup_load_tasks.append(task)
|
||||
|
||||
async def load_mcp_server(self, server_config: dict) -> RuntimeMCPSession:
|
||||
@@ -207,7 +234,6 @@ class MCPLoader(loader.ToolLoader):
|
||||
}
|
||||
|
||||
session = RuntimeMCPSession(name, mixed_config, enable, self.ap)
|
||||
await session.initialize()
|
||||
|
||||
return session
|
||||
|
||||
|
||||
Reference in New Issue
Block a user