Files
LangBot/src/langbot/pkg/provider/tools/loaders/mcp.py
youhuanghe 791d052687 feat(box/mcp): instance-based orphan cleanup, error classification, session API, and integration tests
## Changes

  ### Precise orphan container cleanup
  - Runtime generates a unique instance_id on startup
  - Every container gets a `langbot.box.instance_id` label
  - `cleanup_orphaned_containers()` only removes containers from
    previous instances, preserving containers owned by the current one
  - Containers from older versions (no label) are also cleaned up
  - `cleanup_orphaned_containers` added to `BaseSandboxBackend` as
    a no-op default method, removing hasattr duck-typing

  ### Fine-grained MCP error classification
  - New `MCPSessionErrorPhase` enum with 7 phases: session_create,
    dep_install, process_start, relay_connect, mcp_init, runtime,
    tool_call
  - Each phase in `_init_box_stdio_server()` sets the error phase
    before re-raising, enabling precise failure diagnosis
  - `retry_count` tracked across retry attempts
  - `get_runtime_info_dict()` exposes `error_phase` and `retry_count`

  ### GET /v1/sessions/{id} API
  - `BoxRuntime.get_session()` returns session details including
    managed process info when present
  - `handle_get_session` HTTP handler + route in server.py
  - `BoxRuntimeClient.get_session()` abstract method + remote impl

  ### stdio defaults to Box when runtime is available
  - `_uses_box_stdio()` checks `box_service.available` instead of
    requiring explicit `box` key in server_config
  - `BoxService.initialize()` catches runtime errors gracefully,
    sets `available=False` instead of crashing LangBot startup
  - When no container runtime exists, stdio MCP falls back to
    host-direct execution

  ### Code quality (from /simplify review)
  - Extracted `_VENV_DIRS` / `_VENV_BIN_DIRS` module-level constants
  - Removed dead `_box_network_mode()` method and unused `bc` variable
  - Fixed broken import `from ....box.models` → `from ...box.models`
  - Cached `_resolve_host_path()` result — computed once, passed through
  - Config hash now includes `host_path` field
  - Batched orphan cleanup into single `rm -f` command

  ### Session leak fix
  - `_cleanup_box_stdio_session()` now runs in `_lifecycle_loop`'s
    finally block, covering all exit paths (normal shutdown, error,
    retry, final failure)

  ### Integration tests
  - 6 end-to-end tests covering managed process lifecycle, WebSocket
    stdio bidirectional IO, session cleanup verification, single
    session query, process exit detection, and orphan cleanup safety
2026-05-04 21:23:23 +08:00

816 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import enum
import os
import typing
from contextlib import AsyncExitStack
import traceback
from langbot_plugin.api.entities.events import pipeline_query
import sqlalchemy
import asyncio
import httpx
import pydantic
import uuid as uuid_module
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.client.websocket import websocket_client
from .. import loader
from ....core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ....entity.persistence import mcp as persistence_mcp
class MCPSessionStatus(enum.Enum):
    """Connection state of a runtime MCP session.

    The string values are what ``get_runtime_info_dict()`` reports under
    the ``status`` key.
    """

    # Initial state set by RuntimeMCPSession.__init__ and restored before each retry.
    CONNECTING = 'connecting'
    # Set by _lifecycle_loop once the transport is initialised and tools are listed.
    CONNECTED = 'connected'
    # Set when the lifecycle loop fails or start() times out.
    ERROR = 'error'
class MCPSessionErrorPhase(enum.Enum):
    """Which phase of the MCP lifecycle failed."""

    # Box session creation raised in _init_box_stdio_server.
    SESSION_CREATE = 'session_create'
    # Dependency install command raised or returned a non-ok result.
    DEP_INSTALL = 'dep_install'
    # Starting the managed process inside the Box session raised.
    PROCESS_START = 'process_start'
    # Opening the WebSocket relay / creating the ClientSession raised.
    RELAY_CONNECT = 'relay_connect'
    # ClientSession.initialize() raised on the Box path.
    MCP_INIT = 'mcp_init'
    # The Box managed process exited while the session was connected.
    RUNTIME = 'runtime'
    # NOTE(review): not assigned anywhere in the code visible in this file.
    TOOL_CALL = 'tool_call'
# Directory names treated as a virtualenv root when inferring the project
# directory from an interpreter path (see _unwrap_venv_path / _rewrite_venv_command).
_VENV_DIRS = frozenset({'.venv', 'venv', 'env', '.env'})
# Names of the executable directory inside a virtualenv ('Scripts' on Windows).
_VENV_BIN_DIRS = frozenset({'bin', 'Scripts'})
class MCPServerBoxConfig(pydantic.BaseModel):
    """Structured configuration for running an MCP server inside a Box container.

    Parsed from the ``box`` key of a server config; fields left as ``None``
    are omitted from the Box session payload built by
    ``RuntimeMCPSession._build_box_session_payload``.
    """

    image: str | None = None
    network: str = 'on'  # MCP servers need network for dependency installation
    # Explicit host directory to mount; when unset the session tries to infer one.
    host_path: str | None = None
    host_path_mode: str = 'ro'  # MCP servers default to read-only mount
    env: dict[str, str] = pydantic.Field(default_factory=dict)
    startup_timeout_sec: int = 120  # Longer default to allow pip install
    cpus: float | None = None
    memory_mb: int | None = None
    pids_limit: int | None = None
    # None is sent to the Box session as False (writable rootfs).
    read_only_rootfs: bool | None = None

    # Unknown keys in the ``box`` mapping are ignored rather than rejected.
    model_config = pydantic.ConfigDict(extra='ignore')
class RuntimeMCPSession:
    """Runtime MCP session.

    Owns the connection to one MCP server (stdio, SSE or streamable HTTP),
    the tools listed from it, and the background task that keeps the
    connection alive.
    """

    ap: app.Application

    server_name: str

    server_uuid: str

    server_config: dict

    # None until a transport has been initialised, and again after cleanup.
    session: ClientSession | None

    # Holds the transport and ClientSession context managers for one attempt.
    exit_stack: AsyncExitStack

    # NOTE: class-level mutable default; __init__ rebinds a fresh list per instance.
    functions: list[resource_tool.LLMTool] = []

    enable: bool

    # connected: bool
    status: MCPSessionStatus

    # Background task running _lifecycle_loop_with_retry (created in start()).
    _lifecycle_task: asyncio.Task | None

    # Set by shutdown() to ask the lifecycle task to exit.
    _shutdown_event: asyncio.Event

    # Set once the connection is established or all retries are exhausted.
    _ready_event: asyncio.Event

    # Failure details; these keep their class-level defaults until a failure is recorded.
    error_message: str | None = None
    error_phase: MCPSessionErrorPhase | None = None
    retry_count: int = 0
def __init__(self, server_name: str, server_config: dict, enable: bool, ap: app.Application):
    """Create a session object; no connection is opened until ``start()``.

    Args:
        server_name: Display name of the MCP server.
        server_config: Merged server configuration (mode, uuid, transport
            specific keys, optional ``box`` mapping).
        enable: Whether ``start()`` should actually connect.
        ap: Application object providing ``logger`` and ``box_service``.
    """
    self.server_name = server_name
    self.server_uuid = server_config.get('uuid', '')
    self.server_config = server_config
    self.ap = ap
    self.enable = enable
    self.session = None

    self.exit_stack = AsyncExitStack()
    self.functions = []

    self.status = MCPSessionStatus.CONNECTING

    self._lifecycle_task = None
    self._shutdown_event = asyncio.Event()
    self._ready_event = asyncio.Event()

    # Parse box config once. A config that carries ``"box": null`` must be
    # treated like a missing key; ``.get('box', {})`` would hand None to
    # model_validate and fail construction of the whole session.
    self.box_config = MCPServerBoxConfig.model_validate(server_config.get('box') or {})
async def _init_stdio_python_server(self):
    """Connect to a stdio MCP server, inside Box when available, else on the host.

    On the host path the server process is spawned via ``stdio_client`` and
    both the transport and the ``ClientSession`` are registered on
    ``self.exit_stack`` so they are closed together.

    Raises:
        KeyError: if ``command`` is missing from the server config.
    """
    if self._uses_box_stdio():
        await self._init_box_stdio_server()
        return

    # ``args`` and ``env`` are optional here, matching the Box path
    # (_build_box_process_payload), instead of failing with KeyError.
    server_params = StdioServerParameters(
        command=self.server_config['command'],
        args=self.server_config.get('args') or [],
        env=self.server_config.get('env'),
    )

    stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))

    stdio, write = stdio_transport

    self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))

    await self.session.initialize()
async def _init_box_stdio_server(self):
    """Bring up a stdio MCP server inside a Box session, phase by phase.

    Phases, in order: create the Box session, optionally install the
    project's dependencies, start the managed process, connect to it over
    the WebSocket relay, and run the MCP handshake. When a phase fails,
    ``self.error_phase`` is set to the matching ``MCPSessionErrorPhase``
    before the exception propagates.
    """
    service = self.ap.box_service
    box_session_id = self._build_box_session_id()
    mount_path = self._resolve_host_path()
    base_payload = self._build_box_session_payload(box_session_id, mount_path)

    # Phase: session creation
    try:
        await service.create_session(base_payload, skip_host_mount_validation=True)
    except Exception:
        self.error_phase = MCPSessionErrorPhase.SESSION_CREATE
        raise

    # Phase: dependency installation (only when a project directory is mounted
    # and it contains a recognised dependency manifest)
    install_cmd = self._detect_install_command(mount_path) if mount_path else None
    if install_cmd:
        self.ap.logger.info(
            f'MCP server {self.server_name}: installing dependencies in Box '
            f'with: {install_cmd}'
        )
        install_payload = {
            **base_payload,
            'cmd': install_cmd,
            'timeout_sec': self.box_config.startup_timeout_sec or 120,
        }
        try:
            outcome = await service.client.execute(
                service.build_spec(install_payload, skip_host_mount_validation=True)
            )
        except Exception:
            self.error_phase = MCPSessionErrorPhase.DEP_INSTALL
            raise
        if not outcome.ok:
            self.error_phase = MCPSessionErrorPhase.DEP_INSTALL
            # Only the first 500 characters of stderr go into the message.
            stderr_preview = (outcome.stderr or '')[:500]
            raise Exception(
                f'Dependency install failed (exit code {outcome.exit_code}): '
                f'{stderr_preview}'
            )

    # Phase: managed process start
    try:
        await service.start_managed_process(
            box_session_id,
            self._build_box_process_payload(mount_path),
        )
    except Exception:
        self.error_phase = MCPSessionErrorPhase.PROCESS_START
        raise

    # Phase: WebSocket relay connection
    try:
        relay_url = service.get_managed_process_websocket_url(box_session_id)
        reader, writer = await self.exit_stack.enter_async_context(websocket_client(relay_url))
        self.session = await self.exit_stack.enter_async_context(ClientSession(reader, writer))
    except Exception:
        self.error_phase = MCPSessionErrorPhase.RELAY_CONNECT
        raise

    # Phase: MCP protocol initialization
    try:
        await self.session.initialize()
    except Exception:
        self.error_phase = MCPSessionErrorPhase.MCP_INIT
        raise
async def _init_sse_server(self):
    """Connect to an MCP server over SSE and run the MCP handshake.

    Reads ``url`` (required) plus optional ``headers``, ``timeout`` and
    ``ssereadtimeout`` from the server config. The transport and the
    ``ClientSession`` are both registered on ``self.exit_stack``.
    """
    cfg = self.server_config
    transport_cm = sse_client(
        cfg['url'],
        headers=cfg.get('headers', {}),
        timeout=cfg.get('timeout', 10),
        sse_read_timeout=cfg.get('ssereadtimeout', 30),
    )
    read_stream, write_stream = await self.exit_stack.enter_async_context(transport_cm)

    self.session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))

    await self.session.initialize()
async def _init_streamable_http_server(self):
    """Connect to an MCP server over streamable HTTP and run the handshake.

    Reads ``url`` (required) plus optional ``headers`` and ``timeout`` from
    the server config.

    The ``httpx.AsyncClient`` is created here and handed to the transport,
    so this session owns it. It is registered on ``self.exit_stack`` first,
    which means it is closed last (after the ClientSession and transport)
    when the stack unwinds. Previously the client was never closed and its
    connection pool leaked on every reconnect.
    """
    url = self.server_config['url']

    http_client = await self.exit_stack.enter_async_context(
        httpx.AsyncClient(
            headers=self.server_config.get('headers', {}),
            timeout=self.server_config.get('timeout', 10),
            follow_redirects=True,
        )
    )

    transport = await self.exit_stack.enter_async_context(
        streamable_http_client(url, http_client=http_client)
    )

    # The third element of the transport tuple is not used by this session.
    read, write, _ = transport

    self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))

    await self.session.initialize()
# Number of retries after the first failed attempt (so up to 4 attempts in total).
_MAX_RETRIES = 3
# Backoff in seconds before retry 1, 2 and 3; indexed by the failed attempt number.
_RETRY_DELAYS = [2, 4, 8]
async def _lifecycle_loop(self):
    """Manage the full MCP session lifecycle in a background task.

    Runs one connection attempt end to end: initialise the transport for
    the configured mode, list tools, mark the session CONNECTED, then block
    until shutdown is requested (or, for Box stdio, until the managed
    process exits). All resources are released in ``finally`` in this same
    task.

    Raises:
        Exception: any initialisation/runtime failure is logged, recorded in
            ``status`` / ``error_message`` and re-raised for the retry wrapper.
    """
    try:
        if self.server_config['mode'] == 'stdio':
            await self._init_stdio_python_server()
        elif self.server_config['mode'] == 'sse':
            await self._init_sse_server()
        elif self.server_config['mode'] == 'http':
            await self._init_streamable_http_server()
        else:
            raise ValueError(f'Unknown MCP server mode: {self.server_name}: {self.server_config}')

        await self.refresh()

        self.status = MCPSessionStatus.CONNECTED

        # Notify start() that connection is established
        self._ready_event.set()

        # Wait for shutdown signal, with optional health monitoring for Box stdio
        if self._uses_box_stdio():
            monitor_task = asyncio.create_task(self._monitor_box_process_health())
            shutdown_task = asyncio.create_task(self._shutdown_event.wait())
            done, pending = await asyncio.wait(
                [shutdown_task, monitor_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
            # Whichever waiter did not finish is no longer needed.
            for task in pending:
                task.cancel()
            for task in done:
                # The monitor returning without a shutdown request means the
                # process went away on its own; surface that as a runtime error.
                if task is monitor_task and not self._shutdown_event.is_set():
                    self.error_phase = MCPSessionErrorPhase.RUNTIME
                    raise Exception('Box managed process exited unexpectedly')
        else:
            await self._shutdown_event.wait()

    except Exception as e:
        self.status = MCPSessionStatus.ERROR
        self.error_message = str(e)
        self.ap.logger.error(f'Error in MCP session lifecycle {self.server_name}: {e}\n{traceback.format_exc()}')
        # Do NOT set _ready_event here — let _lifecycle_loop_with_retry
        # handle retries first. It will set the event when all retries
        # are exhausted or on success.
        raise  # Re-raise so _lifecycle_loop_with_retry can catch it
    finally:
        # Clean up all resources in the same task
        try:
            if self.exit_stack:
                await self.exit_stack.aclose()
            # A closed stack is replaced so a later attempt starts clean.
            self.exit_stack = AsyncExitStack()
            self.functions.clear()
            self.session = None
        except Exception as e:
            self.ap.logger.error(f'Error cleaning up MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
        finally:
            # Runs even if closing the exit stack failed.
            await self._cleanup_box_stdio_session()
async def _lifecycle_loop_with_retry(self):
"""Wrap _lifecycle_loop with retry and exponential backoff."""
for attempt in range(self._MAX_RETRIES + 1):
try:
await self._lifecycle_loop()
return # Normal shutdown, don't retry
except Exception as e:
self.retry_count = attempt + 1
if self._shutdown_event.is_set():
return # Shutdown requested, don't retry
if attempt >= self._MAX_RETRIES:
self.status = MCPSessionStatus.ERROR
self.error_message = f'Failed after {self._MAX_RETRIES + 1} attempts: {e}'
self._ready_event.set()
return
delay = self._RETRY_DELAYS[attempt]
self.ap.logger.warning(
f'MCP session {self.server_name} failed (attempt {attempt + 1}), '
f'retrying in {delay}s: {e}'
)
await self._cleanup_box_stdio_session()
# Reset status for retry
self.status = MCPSessionStatus.CONNECTING
self.error_message = None
self.error_phase = None
await asyncio.sleep(delay)
async def _monitor_box_process_health(self):
    """Poll managed process status; return when process exits.

    Polls every 5 seconds until shutdown is requested. Returns early when
    the reported status is EXITED, or when the status query itself fails.
    """
    from ...box.models import BoxManagedProcessStatus

    box_session_id = self._build_box_session_id()
    while not self._shutdown_event.is_set():
        try:
            info = await self.ap.box_service.client.get_managed_process(box_session_id)
            # The client may hand back either a plain dict or an object.
            status = info.get('status', '') if isinstance(info, dict) else getattr(info, 'status', '')
            if status in (BoxManagedProcessStatus.EXITED.value, BoxManagedProcessStatus.EXITED):
                return
        except Exception:
            return  # Process or session gone
        await asyncio.sleep(5)
async def start(self):
    """Start the background lifecycle task and wait until it is ready.

    Does nothing when the session is disabled. Otherwise waits for
    ``_ready_event`` for at most ``box_config.startup_timeout_sec`` seconds
    (Box stdio) or 30 seconds (all other modes).

    Raises:
        Exception: on startup timeout, or when the lifecycle task ended in
            ERROR state. The failure message carries the recorded
            ``error_message`` when there is one, instead of always telling
            the user to check a URL (stdio servers have none).
    """
    if not self.enable:
        return

    # Create background task for lifecycle management with retry
    self._lifecycle_task = asyncio.create_task(self._lifecycle_loop_with_retry())

    # Wait for connection or failure (with timeout)
    startup_timeout = self.box_config.startup_timeout_sec if self._uses_box_stdio() else 30.0
    try:
        await asyncio.wait_for(self._ready_event.wait(), timeout=startup_timeout)
    except asyncio.TimeoutError:
        self.status = MCPSessionStatus.ERROR
        raise Exception(f'Connection timeout after {startup_timeout} seconds')

    # Check for errors
    if self.status == MCPSessionStatus.ERROR:
        if self.error_message:
            raise Exception(f'Connection failed: {self.error_message}')
        raise Exception('Connection failed, please check URL')
async def refresh(self):
    """Re-list the server's tools and rebuild ``self.functions``.

    Does nothing when there is no active session. Each listed tool is
    wrapped in an async callable that forwards keyword arguments to
    ``session.call_tool`` and converts the result content into
    ``ContentElement`` objects.
    """
    if not self.session:
        return

    self.functions.clear()

    tools = await self.session.list_tools()

    self.ap.logger.debug(f'Refresh MCP tools: {tools}')

    for tool in tools.tools:

        # ``_tool=tool`` binds the current tool at definition time; without it
        # every wrapper would call the last tool of the loop.
        async def func(*, _tool=tool, **kwargs):
            if not self.session:
                raise Exception('MCP session is not connected')

            result = await self.session.call_tool(_tool.name, kwargs)
            if result.isError:
                error_texts = []
                for content in result.content:
                    if content.type == 'text':
                        error_texts.append(content.text)
                raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')

            result_contents: list[provider_message.ContentElement] = []
            for content in result.content:
                if content.type == 'text':
                    result_contents.append(provider_message.ContentElement.from_text(content.text))
                elif content.type == 'image':
                    # MCP ImageContent carries the base64 payload in ``data``
                    # (with ``mimeType`` alongside); it has no ``image_base64``
                    # attribute, so the old access raised AttributeError.
                    # NOTE(review): confirm whether from_image_base64 expects
                    # raw base64 or a ``data:<mime>;base64,`` URL.
                    result_contents.append(provider_message.ContentElement.from_image_base64(content.data))
                elif content.type == 'resource':
                    # TODO: Handle resource content
                    pass

            return result_contents

        func.__name__ = tool.name

        self.functions.append(
            resource_tool.LLMTool(
                name=tool.name,
                human_desc=tool.description or '',
                description=tool.description or '',
                parameters=tool.inputSchema,
                func=func,
            )
        )
def get_tools(self) -> list[resource_tool.LLMTool]:
    """Return the list of tools currently registered for this session.

    The returned object is the session's own list, not a copy.
    """
    registered = self.functions
    return registered
def get_runtime_info_dict(self) -> dict:
    """Return a serialisable snapshot of the session's runtime state.

    Always contains status, error details, retry count and the tool list;
    ``box_session_id`` / ``box_enabled`` are added only for Box stdio sessions.
    """
    phase = self.error_phase
    tool_summaries = [
        {
            'name': tool.name,
            'description': tool.description,
        }
        for tool in self.get_tools()
    ]

    info = {
        'status': self.status.value,
        'error_message': self.error_message,
        'error_phase': phase.value if phase else None,
        'retry_count': self.retry_count,
        'tool_count': len(self.get_tools()),
        'tools': tool_summaries,
    }

    if not self._uses_box_stdio():
        return info

    info['box_session_id'] = self._build_box_session_id()
    info['box_enabled'] = True
    return info
async def shutdown(self):
    """Shut the session down and release its resources.

    Signals the lifecycle task to exit and gives it 5 seconds to finish
    before cancelling it. Errors are logged, never raised.
    """
    try:
        # Tell the lifecycle task to exit.
        self._shutdown_event.set()

        # Wait for the lifecycle task to finish, with a timeout.
        lifecycle_task = self._lifecycle_task
        if lifecycle_task and not lifecycle_task.done():
            try:
                await asyncio.wait_for(lifecycle_task, timeout=5.0)
            except asyncio.TimeoutError:
                self.ap.logger.warning(f'MCP session {self.server_name} shutdown timeout, cancelling task')
                lifecycle_task.cancel()
                try:
                    await lifecycle_task
                except asyncio.CancelledError:
                    pass

        self.ap.logger.info(f'MCP session {self.server_name} shutdown complete')

    except Exception as e:
        self.ap.logger.error(f'Error shutting down MCP session {self.server_name}: {e}\n{traceback.format_exc()}')
def _uses_box_stdio(self) -> bool:
"""Check whether this stdio MCP server should run inside a Box container.
Returns True when mode is stdio AND the Box runtime is available.
An explicit ``box`` key in server_config is NOT required — if the
runtime is reachable, stdio servers default to Box isolation.
"""
if self.server_config.get('mode') != 'stdio':
return False
try:
return getattr(self.ap.box_service, 'available', False)
except Exception:
return False
def _build_box_session_id(self) -> str:
return f'mcp-{self.server_uuid}'
def _rewrite_path(self, path: str, host_path: str | None) -> str:
"""Rewrite host path prefix to container /workspace prefix."""
if not host_path or not path:
return path
normalized_host = os.path.realpath(host_path)
if path.startswith(normalized_host + '/'):
return '/workspace' + path[len(normalized_host):]
if path == normalized_host:
return '/workspace'
return path
def _infer_host_path(self) -> str | None:
"""Try to infer host_path from command and args absolute paths.
Detects virtualenv patterns (e.g. .venv/bin/python) and walks up
to the project root rather than using the bin directory.
"""
candidates = []
parts = [self.server_config.get('command', '')] + self.server_config.get('args', [])
for part in parts:
if not os.path.isabs(part):
continue
# Use the raw path for venv detection (before resolving symlinks)
# because .venv/bin/python is often a symlink to the system python.
if os.path.exists(part):
directory = os.path.dirname(part)
directory = self._unwrap_venv_path(directory)
candidates.append(os.path.realpath(directory))
if not candidates:
return None
common = os.path.commonpath(candidates)
return common if common != '/' else None
@staticmethod
def _unwrap_venv_path(directory: str) -> str:
    """Return the project root when ``directory`` is inside a virtualenv bin dir.

    Recognized patterns:
        /project/.venv/bin      -> /project
        /project/venv/bin       -> /project
        /project/.venv/Scripts  -> /project  (Windows)
        /project/env/bin        -> /project

    Any other directory is returned unchanged. Backslashes are treated as
    path separators, and a rewritten result always uses forward slashes.
    """
    segments = directory.replace('\\', '/').split('/')

    # Scan from the deepest component upwards for <venv dir>/<bin dir>.
    for idx in reversed(range(1, len(segments))):
        if segments[idx] in _VENV_BIN_DIRS and segments[idx - 1] in _VENV_DIRS:
            # Everything before the venv directory is the project root.
            root = '/'.join(segments[: idx - 1])
            return root or '/'

    return directory
def _resolve_host_path(self) -> str | None:
"""Resolve the effective host_path: explicit config > inference."""
return self.box_config.host_path or self._infer_host_path()
@staticmethod
def _detect_install_command(host_path: str) -> str | None:
"""Detect how to install dependencies from the mounted project.
Copies the project to a writable temp directory before installing,
because /workspace may be mounted read-only and pip needs to write
build artifacts in the source tree.
"""
_COPY_AND_INSTALL = (
'cp -r /workspace /tmp/_mcp_src'
' && pip install --no-cache-dir /tmp/_mcp_src'
' && rm -rf /tmp/_mcp_src'
)
_INSTALL_REQUIREMENTS = 'pip install --no-cache-dir -r /workspace/requirements.txt'
if os.path.isfile(os.path.join(host_path, 'pyproject.toml')):
return _COPY_AND_INSTALL
if os.path.isfile(os.path.join(host_path, 'setup.py')):
return _COPY_AND_INSTALL
if os.path.isfile(os.path.join(host_path, 'requirements.txt')):
return _INSTALL_REQUIREMENTS
return None
def _build_box_session_payload(self, session_id: str, host_path: str | None = None) -> dict:
bc = self.box_config
if host_path is None:
host_path = self._resolve_host_path()
payload: dict[str, typing.Any] = {
'session_id': session_id,
'workdir': '/workspace',
'env': bc.env,
# MCP sessions need network for dependency install and writable rootfs
'network': bc.network,
'read_only_rootfs': bc.read_only_rootfs if bc.read_only_rootfs is not None else False,
}
if host_path:
payload['host_path'] = host_path
payload['host_path_mode'] = bc.host_path_mode
for key in ('image', 'cpus', 'memory_mb', 'pids_limit'):
val = getattr(bc, key)
if val is not None:
payload[key] = val if not isinstance(val, enum.Enum) else val.value
return payload
def _build_box_process_payload(self, host_path: str | None = None) -> dict:
if host_path is None:
host_path = self._resolve_host_path()
command = self.server_config['command']
args = self.server_config.get('args', [])
cwd = '/workspace'
if host_path:
# When host_path is resolved, we install deps in-container rather
# than relying on the host venv. Rewrite paths so the container
# sees /workspace/... but replace venv python with plain "python".
command = self._rewrite_venv_command(command, host_path)
args = [self._rewrite_path(a, host_path) for a in args]
cwd = self._rewrite_path(cwd, host_path)
return {
'command': command,
'args': args,
'env': self.server_config.get('env', {}),
'cwd': cwd,
}
def _rewrite_venv_command(self, command: str, host_path: str) -> str:
    """Rewrite the server command for use inside the container.

    A command pointing at a virtualenv interpreter under ``host_path``
    (``<venv dir>/<bin dir>/python*``) becomes plain ``python``. Any other
    command under ``host_path`` gets the normal /workspace path rewrite.
    Commands outside ``host_path``, and empty inputs, are returned unchanged.
    """
    if not host_path or not command:
        return command

    host_root = os.path.realpath(host_path)
    if not command.startswith(host_root + '/'):
        return command

    # Path relative to the host root, e.g. ".venv/bin/python".
    segments = command[len(host_root) + 1:].replace('\\', '/').split('/')
    points_at_venv_python = (
        len(segments) >= 3
        and segments[0] in _VENV_DIRS
        and segments[1] in _VENV_BIN_DIRS
        and segments[2].startswith('python')
    )
    if points_at_venv_python:
        return 'python'

    # Not a venv python — do normal path rewrite
    return self._rewrite_path(command, host_path)
async def _cleanup_box_stdio_session(self) -> None:
if not self._uses_box_stdio():
return
try:
await self.ap.box_service.client.delete_session(self._build_box_session_id())
except Exception as e:
self.ap.logger.warning(f'Failed to cleanup Box session for MCP server {self.server_name}: {e}')
# @loader.loader_class('mcp')
class MCPLoader(loader.ToolLoader):
    """MCP tool loader.

    Manages all connections to MCP servers in this loader.
    """

    # Runtime sessions keyed by server name.
    sessions: dict[str, RuntimeMCPSession]

    # Result of the most recent get_tools() call.
    _last_listed_functions: list[resource_tool.LLMTool]

    # Background tasks created by load_mcp_servers_from_db, one per server.
    _hosted_mcp_tasks: list[asyncio.Task]
def __init__(self, ap: app.Application):
    """Create an empty loader bound to the application object ``ap``."""
    super().__init__(ap)
    self._hosted_mcp_tasks = []
    self._last_listed_functions = []
    self.sessions = {}
async def initialize(self):
    """Initialise the loader by loading every MCP server stored in the database."""
    await self.load_mcp_servers_from_db()
    return None
async def load_mcp_servers_from_db(self):
    """Read all MCP server rows and start hosting each one in the background.

    Resets ``self.sessions`` and creates one ``host_mcp_server`` task per
    stored server. The tasks are kept in ``_hosted_mcp_tasks`` and are not
    awaited here.
    """
    self.ap.logger.info('Loading MCP servers from db...')

    self.sessions = {}

    query = sqlalchemy.select(persistence_mcp.MCPServer)
    rows = (await self.ap.persistence_mgr.execute_async(query)).all()

    for row in rows:
        server_config = self.ap.persistence_mgr.serialize_model(persistence_mcp.MCPServer, row)
        hosting_task = asyncio.create_task(self.host_mcp_server(server_config))
        self._hosted_mcp_tasks.append(hosting_task)
async def host_mcp_server(self, server_config: dict):
    """Create the runtime session for one server config and start it.

    The session is registered in ``self.sessions`` under the server name
    before it is started, so a session whose start fails stays registered.
    Load and start failures are logged, never raised.
    """

    def describe() -> str:
        # Evaluated lazily: load_mcp_server may fill in a missing uuid.
        return f'{server_config["name"]}({server_config["uuid"]})'

    self.ap.logger.debug(f'Loading MCP server {server_config}')
    try:
        session = await self.load_mcp_server(server_config)
        self.sessions[server_config['name']] = session
    except Exception as e:
        self.ap.logger.error(
            f'Failed to load MCP server from db: {describe()}: {e}\n{traceback.format_exc()}'
        )
        return

    self.ap.logger.debug(f'Starting MCP server {describe()}')
    try:
        await session.start()
    except Exception as e:
        self.ap.logger.error(
            f'Failed to start MCP server {describe()}: {e}\n{traceback.format_exc()}'
        )
        return

    self.ap.logger.debug(f'Started MCP server {describe()}')
async def load_mcp_server(self, server_config: dict) -> RuntimeMCPSession:
    """Build a runtime session object for an MCP server (does not start it).

    Args:
        server_config: Server configuration dict, which must contain:
            - name: server name
            - mode: connection mode (stdio/sse/http)
            - enable: whether the server is enabled
            - extra_args: additional configuration entries (optional)

    Returns:
        A new ``RuntimeMCPSession``. When the config has no UUID, a random
        one is generated and written back into ``server_config``.
    """
    if not server_config.get('uuid'):
        self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
        server_config['uuid'] = str(uuid_module.uuid4())

    name = server_config['name']
    enable = server_config['enable']

    # extra_args is spread last, so its keys override the base keys.
    mixed_config = {
        'name': name,
        'uuid': server_config['uuid'],
        'mode': server_config['mode'],
        'enable': enable,
        **server_config.get('extra_args', {}),
    }

    return RuntimeMCPSession(name, mixed_config, enable, self.ap)
async def get_tools(self, bound_mcp_servers: list[str] | None = None) -> list[resource_tool.LLMTool]:
    """Collect tools from the loaded sessions.

    Args:
        bound_mcp_servers: Server UUIDs to take tools from. None means all
            sessions; an empty list means no session.

    Returns:
        The collected tools, which are also stored in ``_last_listed_functions``.
    """
    collected = []

    for session in self.sessions.values():
        # No filter given, or this session's server is one of the bound ones.
        if bound_mcp_servers is None or session.server_uuid in bound_mcp_servers:
            collected.extend(session.get_tools())

    self._last_listed_functions = collected
    return collected
async def has_tool(self, name: str) -> bool:
    """Check whether any loaded session exposes a tool called ``name``."""
    return any(
        function.name == name
        for session in self.sessions.values()
        for function in session.get_tools()
    )
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
    """Invoke the first tool called ``name`` found in the loaded sessions.

    Args:
        name: Tool name to look up.
        parameters: Keyword arguments passed to the tool.
        query: Pipeline query; not used by this loader.

    Raises:
        ValueError: if no session exposes a tool with that name.
        Exception: whatever the tool raises; it is logged and re-raised.
    """
    matches = (
        function
        for session in self.sessions.values()
        for function in session.get_tools()
        if function.name == name
    )
    target = next(matches, None)
    if target is None:
        raise ValueError(f'Tool not found: {name}')

    self.ap.logger.debug(f'Invoking MCP tool: {name} with parameters: {parameters}')
    try:
        result = await target.func(**parameters)
        self.ap.logger.debug(f'MCP tool {name} executed successfully')
        return result
    except Exception as e:
        self.ap.logger.error(f'Error invoking MCP tool {name}: {e}\n{traceback.format_exc()}')
        raise
async def remove_mcp_server(self, server_name: str):
    """Remove an MCP server: unregister its session, then shut it down.

    An unknown server name is logged as a warning and otherwise ignored.
    """
    try:
        session = self.sessions.pop(server_name)
    except KeyError:
        self.ap.logger.warning(f'MCP server {server_name} not found in sessions, skipping removal')
        return

    await session.shutdown()
    self.ap.logger.info(f'Removed MCP server: {server_name}')
def get_session(self, server_name: str) -> RuntimeMCPSession | None:
    """Return the MCP session registered under ``server_name``, or None."""
    try:
        return self.sessions[server_name]
    except KeyError:
        return None
def has_session(self, server_name: str) -> bool:
    """Check whether an MCP session is registered under ``server_name``."""
    is_registered = server_name in self.sessions
    return is_registered
def get_all_server_names(self) -> list[str]:
    """Return the names of all loaded MCP servers."""
    return [*self.sessions]
def get_server_tool_count(self, server_name: str) -> int:
    """Return how many tools the named server exposes (0 if it is not loaded)."""
    session = self.get_session(server_name)
    if not session:
        return 0
    return len(session.get_tools())
def get_all_servers_info(self) -> dict[str, dict]:
    """Return a summary for every loaded server, keyed by server name.

    Each summary holds the name, connection mode, enable flag, tool count
    and tool names.
    """
    return {
        server_name: {
            'name': server_name,
            'mode': session.server_config.get('mode'),
            'enable': session.enable,
            'tools_count': len(session.get_tools()),
            'tool_names': [f.name for f in session.get_tools()],
        }
        for server_name, session in self.sessions.items()
    }
async def shutdown(self):
    """Shut down every MCP session and clear the session registry.

    A failure to shut down one session is logged and does not stop the rest.
    """
    self.ap.logger.info('Shutting down all MCP sessions...')

    # Iterate over a snapshot of the registry.
    snapshot = tuple(self.sessions.items())
    for server_name, session in snapshot:
        try:
            await session.shutdown()
            self.ap.logger.debug(f'Shutdown MCP session: {server_name}')
        except Exception as e:
            self.ap.logger.error(f'Error shutting down MCP session {server_name}: {e}\n{traceback.format_exc()}')

    self.sessions.clear()
    self.ap.logger.info('All MCP sessions shutdown complete')