fix: harden agent runner runtime boundaries

This commit is contained in:
huanghuoguoguo
2026-06-13 00:17:40 +08:00
parent 2094993afb
commit e7779bd16f
22 changed files with 366 additions and 889 deletions

View File

@@ -16,6 +16,7 @@ from ...entity.persistence.artifact import AgentArtifact
from ...entity.persistence.bstorage import BinaryStorage
_FILE_ARTIFACT_METADATA_KEY = '_langbot_file_artifact'
_ARTIFACT_THREAD_METADATA_KEY = '_langbot_thread_id'
class ArtifactStore:
@@ -56,6 +57,7 @@ class ArtifactStore:
runner_id: str | None = None,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
expires_at: datetime.datetime | None = None,
metadata: dict[str, typing.Any] | None = None,
) -> str:
@@ -92,6 +94,7 @@ class ArtifactStore:
runner_id=runner_id,
bot_id=bot_id,
workspace_id=workspace_id,
thread_id=thread_id,
expires_at=expires_at,
metadata=public_metadata,
content=None,
@@ -113,6 +116,7 @@ class ArtifactStore:
runner_id: str | None = None,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
expires_at: datetime.datetime | None = None,
metadata: dict[str, typing.Any] | None = None,
content: bytes | None = None,
@@ -137,6 +141,7 @@ class ArtifactStore:
runner_id: Runner ID that created this
bot_id: Bot UUID
workspace_id: Workspace ID
thread_id: Thread ID stored as Host-only metadata
expires_at: Expiration time
metadata: Additional metadata
content: Optional content to store in BinaryStorage
@@ -147,6 +152,10 @@ class ArtifactStore:
if artifact_id is None:
artifact_id = str(uuid.uuid4())
metadata_payload = dict(metadata or {})
if thread_id is not None:
metadata_payload[_ARTIFACT_THREAD_METADATA_KEY] = thread_id
# If content provided, store in BinaryStorage
if content is not None and storage_key is None:
storage_key = f"artifact:{artifact_id}"
@@ -184,7 +193,7 @@ class ArtifactStore:
workspace_id=workspace_id,
created_at=datetime.datetime.utcnow(),
expires_at=expires_at,
metadata_json=json.dumps(metadata) if metadata else None,
metadata_json=json.dumps(metadata_payload) if metadata_payload else None,
)
session.add(artifact)
await session.commit()
@@ -216,6 +225,22 @@ class ArtifactStore:
return None
return self._row_to_public_dict(row)
async def get_authorization_metadata(
self,
artifact_id: str,
) -> dict[str, typing.Any] | None:
"""Get artifact metadata with Host-only scope fields for authorization."""
row = await self._get_internal_record(artifact_id)
if row is None:
return None
metadata = self._row_to_public_dict(row)
metadata.update({
'bot_id': row.bot_id,
'workspace_id': row.workspace_id,
'thread_id': self._load_metadata(row.metadata_json).get(_ARTIFACT_THREAD_METADATA_KEY),
})
return metadata
async def _get_internal_record(
self,
artifact_id: str,
@@ -455,6 +480,7 @@ class ArtifactStore:
def _public_metadata(metadata_json: str | None) -> dict[str, typing.Any]:
metadata = ArtifactStore._load_metadata(metadata_json)
metadata.pop(_FILE_ARTIFACT_METADATA_KEY, None)
metadata.pop(_ARTIFACT_THREAD_METADATA_KEY, None)
return metadata
@staticmethod

View File

@@ -6,6 +6,10 @@ import typing
from .descriptor import AgentRunnerDescriptor
FORM_ITEM_TYPE_ALIASES = {
'select-llm-model': 'llm-model-selector',
'select-knowledge-bases': 'knowledge-base-multi-selector',
}
LLM_MODEL_SELECTOR_TYPES = {'model-fallback-selector', 'llm-model-selector'}
KB_SELECTOR_TYPES = {'knowledge-base-multi-selector'}
PROMPT_EDITOR_TYPES = {'prompt-editor'}
@@ -13,6 +17,13 @@ FILE_SELECTOR_TYPES = {'file', 'array[file]'}
NONE_SENTINELS = {'', '__none__', '__none'}
def normalize_schema_item_type(item_type: typing.Any) -> typing.Any:
"""Normalize legacy/frontend DynamicForm aliases to protocol field types."""
if not isinstance(item_type, str):
return item_type
return FORM_ITEM_TYPE_ALIASES.get(item_type, item_type)
def iter_schema_items(
descriptor: AgentRunnerDescriptor | None,
field_types: set[str],
@@ -23,7 +34,7 @@ def iter_schema_items(
for item in descriptor.config_schema or []:
if not isinstance(item, dict):
continue
if item.get('type') in field_types:
if normalize_schema_item_type(item.get('type')) in field_types:
yield item
@@ -81,7 +92,8 @@ def extract_model_selection(
continue
value = runner_config.get(field_name, item.get('default'))
if item.get('type') == 'model-fallback-selector':
item_type = normalize_schema_item_type(item.get('type'))
if item_type == 'model-fallback-selector':
if isinstance(value, str):
primary_uuid = value
elif isinstance(value, dict):
@@ -91,7 +103,7 @@ def extract_model_selection(
fallback_uuids = [fallback for fallback in fallbacks if isinstance(fallback, str)]
break
if item.get('type') == 'llm-model-selector' and isinstance(value, str):
if item_type == 'llm-model-selector' and isinstance(value, str):
primary_uuid = value
break
@@ -145,7 +157,8 @@ def extract_config_file_resources(
if not field_name:
continue
value = runner_config.get(field_name, item.get('default'))
if item.get('type') == 'file':
item_type = normalize_schema_item_type(item.get('type'))
if item_type == 'file':
append_file(value)
elif isinstance(value, list):
for entry in value:
@@ -167,7 +180,7 @@ def iter_config_model_refs(
continue
field_name = item.get('name')
field_type = item.get('type')
field_type = normalize_schema_item_type(item.get('type'))
if not field_name or field_name not in runner_config:
continue
@@ -200,7 +213,7 @@ def set_empty_llm_model_selection(
"""Set the first empty schema-defined LLM selector to model_uuid."""
for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES):
field_name = item.get('name')
field_type = item.get('type')
field_type = normalize_schema_item_type(item.get('type'))
if not field_name:
continue

View File

@@ -283,6 +283,9 @@ class AgentRunOrchestrator:
target_run_id = await self._session_registry.find_steering_target(
conversation_id=event.conversation_id,
runner_id=descriptor.id,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
thread_id=event.thread_id,
)
if target_run_id is None:
return False

View File

@@ -230,6 +230,7 @@ class AgentRunJournal:
runner_id=runner_id,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
thread_id=event.thread_id,
metadata=metadata,
content=content,
)
@@ -371,6 +372,7 @@ class AgentRunJournal:
runner_id=runner_id,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
thread_id=event.thread_id,
metadata=metadata,
content=content,
)

View File

@@ -254,6 +254,9 @@ class AgentRunSessionRegistry:
*,
conversation_id: str,
runner_id: str,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
) -> str | None:
"""Find the oldest active run that can accept steering for a conversation."""
async with self._lock:
@@ -264,6 +267,12 @@ class AgentRunSessionRegistry:
continue
if authorization.get('conversation_id') != conversation_id:
continue
if authorization.get('bot_id') != bot_id:
continue
if authorization.get('workspace_id') != workspace_id:
continue
if authorization.get('thread_id') != thread_id:
continue
if not authorization.get('available_apis', {}).get('steering_pull', False):
continue
candidates.append((session['status'].get('started_at', 0), run_id))

View File

@@ -254,6 +254,10 @@ class TranscriptStore:
self,
conversation_id: str,
limit: int = HARD_LIMIT,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
strict_thread: bool = False,
) -> list[provider_message.Message]:
"""Project Transcript rows into the legacy provider Message view.
@@ -265,6 +269,10 @@ class TranscriptStore:
conversation_id=conversation_id,
limit=limit,
direction="backward",
bot_id=bot_id,
workspace_id=workspace_id,
thread_id=thread_id,
strict_thread=strict_thread,
)
messages: list[provider_message.Message] = []

View File

@@ -236,12 +236,18 @@ class BoxService:
if forced_template:
template = forced_template
else:
template = (
(query.pipeline_config or {})
.get('ai', {})
.get('local-agent', {})
.get('box-session-id-template', '{launcher_type}_{launcher_id}')
template = '{launcher_type}_{launcher_id}'
pipeline_config = query.pipeline_config or {}
ai_config = pipeline_config.get('ai', {}) if isinstance(pipeline_config, dict) else {}
runner_selector = ai_config.get('runner', {}) if isinstance(ai_config, dict) else {}
runner_id = runner_selector.get('id') if isinstance(runner_selector, dict) else None
runner_configs = ai_config.get('runner_config', {}) if isinstance(ai_config, dict) else {}
runner_config = runner_configs.get(runner_id, {}) if isinstance(runner_configs, dict) else {}
configured_template = (
runner_config.get('box-session-id-template') if isinstance(runner_config, dict) else None
)
if isinstance(configured_template, str) and configured_template:
template = configured_template
variables = dict(query.variables or {})
launcher_type = getattr(query, 'launcher_type', None)
if hasattr(launcher_type, 'value'):

View File

@@ -1,7 +1,7 @@
"""Normalize AgentRunner config containers
Revision ID: 0004_migrate_runner_config
Revises: 0003_add_rerank_models
Revision ID: 0005_migrate_runner_config
Revises: 0004_add_mcp_readme
Create Date: 2026-05-10
"""
@@ -11,8 +11,8 @@ from alembic import op
from langbot.pkg.agent.runner.config_migration import ConfigMigration
revision = '0004_migrate_runner_config'
down_revision = '0003_add_rerank_models'
revision = '0005_migrate_runner_config'
down_revision = '0004_add_mcp_readme'
branch_labels = None
depends_on = None

View File

@@ -1,7 +1,7 @@
"""add_event_log_and_transcript_tables
Revision ID: 58846a8d7a81
Revises: 0004_migrate_runner_config
Revises: 0005_migrate_runner_config
Create Date: 2026-05-23 15:41:47.030841
"""
from alembic import op
@@ -9,7 +9,7 @@ import sqlalchemy as sa
# revision identifiers
revision = '58846a8d7a81'
down_revision = '0004_migrate_runner_config'
down_revision = '0005_migrate_runner_config'
branch_labels = None
depends_on = None

View File

@@ -117,6 +117,9 @@ class PreProcessor(stage.PipelineStage):
self,
runner_id: str | None,
conversation_uuid: str | None,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
) -> list[provider_message.Message] | None:
if not runner_id or not conversation_uuid or not self._has_declared_db_engine():
return None
@@ -125,7 +128,13 @@ class PreProcessor(stage.PipelineStage):
from ...agent.runner.transcript_store import TranscriptStore
store = TranscriptStore(self.ap.persistence_mgr.get_db_engine())
messages = await store.get_legacy_provider_messages(str(conversation_uuid))
messages = await store.get_legacy_provider_messages(
str(conversation_uuid),
bot_id=bot_id,
workspace_id=workspace_id,
thread_id=thread_id,
strict_thread=True,
)
except Exception as e:
self.ap.logger.warning(
f'Unable to load Transcript history view for conversation {conversation_uuid}: {e}'
@@ -138,10 +147,15 @@ class PreProcessor(stage.PipelineStage):
self,
runner_id: str | None,
conversation: typing.Any,
bot_id: str | None = None,
workspace_id: str | None = None,
) -> list[provider_message.Message]:
transcript_messages = await self._load_agent_runner_history_messages(
runner_id,
getattr(conversation, 'uuid', None),
bot_id=bot_id,
workspace_id=workspace_id,
thread_id=getattr(conversation, 'thread_id', None),
)
if transcript_messages is not None:
return transcript_messages
@@ -213,7 +227,11 @@ class PreProcessor(stage.PipelineStage):
# Attach resolved session state to the query.
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = await self._resolve_history_messages(runner_id, conversation)
query.messages = await self._resolve_history_messages(
runner_id,
conversation,
bot_id=query.bot_uuid,
)
if uses_host_models:
query.use_funcs = []

View File

@@ -120,7 +120,8 @@ def _validate_artifact_access(
Args:
session: AgentRunSession dict with run_id and authorization snapshot
artifact_metadata: Artifact metadata dict with conversation_id, run_id
artifact_metadata: Artifact metadata dict with conversation_id, run_id,
and Host-only scope fields when available
operation: Operation name for error messages ('metadata' or 'read')
Returns:
@@ -138,7 +139,7 @@ def _validate_artifact_access(
# Rule 2: Same conversation (requires artifact to have conversation_id)
if artifact_conversation_id and session_conversation_id:
if artifact_conversation_id == session_conversation_id:
if artifact_conversation_id == session_conversation_id and _artifact_matches_run_scope(session, artifact_metadata):
return True, None
# Rule 3: Deny - no matching authorization rule
@@ -150,6 +151,21 @@ def _get_run_authorization(session: dict[str, Any]) -> dict[str, Any]:
return session['authorization']
def _artifact_matches_run_scope(session: dict[str, Any], artifact_metadata: dict[str, Any]) -> bool:
authorization = _get_run_authorization(session)
for scope_key in ('bot_id', 'workspace_id', 'thread_id'):
if authorization.get(scope_key) != artifact_metadata.get(scope_key):
return False
return True
def _public_artifact_metadata(artifact_metadata: dict[str, Any]) -> dict[str, Any]:
public_metadata = dict(artifact_metadata)
for scope_key in ('bot_id', 'workspace_id', 'thread_id'):
public_metadata.pop(scope_key, None)
return public_metadata
def _resolve_state_scope(
session: dict[str, Any],
scope: str,
@@ -1864,7 +1880,7 @@ class RuntimeConnectionHandler(handler.Handler):
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
try:
metadata = await store.get_metadata(artifact_id)
metadata = await store.get_authorization_metadata(artifact_id)
if not metadata:
return handler.ActionResponse.error(
message=f'Artifact {artifact_id} not found'
@@ -1875,7 +1891,7 @@ class RuntimeConnectionHandler(handler.Handler):
if not is_allowed:
return handler.ActionResponse.error(message=error_msg)
return handler.ActionResponse.success(data=metadata)
return handler.ActionResponse.success(data=_public_artifact_metadata(metadata))
except Exception as e:
self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Artifact metadata error: {e}')
@@ -1934,7 +1950,7 @@ class RuntimeConnectionHandler(handler.Handler):
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
try:
metadata = await store.get_metadata(artifact_id)
metadata = await store.get_authorization_metadata(artifact_id)
if not metadata:
return handler.ActionResponse.error(
message=f'Artifact {artifact_id} not found'

View File

@@ -1,511 +0,0 @@
"""DeerFlow LangGraph API Runner
参考 astrbot 的 deerflow_agent_runner 实现,适配 LangBot 的 Runner 接口。
特点:
- 使用 LangGraph HTTP API 接入 deer-flow 后端
- 自动管理 thread_id按 session 隔离)
- 支持 SSE 流式响应解析
- 支持 streaming/非流式两种输出
- 处理 values / messages-tuple / custom 三种事件
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import typing
from collections import deque
from dataclasses import dataclass, field
from langbot.pkg.provider import runner
from langbot.pkg.core import app
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot.libs.deerflow_api import client, errors, stream_utils
_MAX_VALUES_HISTORY = 200
@dataclass
class _StreamState:
"""流式状态跟踪"""
latest_text: str = ''
prev_text_for_streaming: str = ''
clarification_text: str = ''
task_failures: list[str] = field(default_factory=list)
seen_message_ids: set[str] = field(default_factory=set)
seen_message_order: deque[str] = field(default_factory=deque)
no_id_message_fingerprints: dict[int, str] = field(default_factory=dict)
baseline_initialized: bool = False
has_values_text: bool = False
run_values_messages: list[dict[str, typing.Any]] = field(default_factory=list)
timed_out: bool = False
@runner.runner_class('deerflow-api')
class DeerFlowAPIRunner(runner.RequestRunner):
"""DeerFlow LangGraph API 对话请求器"""
deerflow_client: client.AsyncDeerFlowClient
def __init__(self, ap: app.Application, pipeline_config: dict):
super().__init__(ap, pipeline_config)
cfg = self.pipeline_config['ai']['deerflow-api']
api_base = cfg.get('api-base', '').strip()
if not api_base or not api_base.startswith(('http://', 'https://')):
raise errors.DeerFlowAPIError(
message='DeerFlow API Base URL 格式错误,必须以 http:// 或 https:// 开头',
)
self.api_base = api_base
self.api_key = cfg.get('api-key', '')
self.auth_header = cfg.get('auth-header', '')
self.assistant_id = cfg.get('assistant-id', 'lead_agent')
self.model_name = cfg.get('model-name', '')
self.thinking_enabled = bool(cfg.get('thinking-enabled', False))
self.plan_mode = bool(cfg.get('plan-mode', False))
self.subagent_enabled = bool(cfg.get('subagent-enabled', False))
self.max_concurrent_subagents = int(cfg.get('max-concurrent-subagents', 3))
self.timeout = int(cfg.get('timeout', 300))
self.recursion_limit = int(cfg.get('recursion-limit', 1000))
self.deerflow_client = client.AsyncDeerFlowClient(
api_base=self.api_base,
api_key=self.api_key,
auth_header=self.auth_header,
)
# ------------------------------------------------------------------
# 辅助方法
# ------------------------------------------------------------------
def _fingerprint_message(self, message: dict[str, typing.Any]) -> str:
try:
raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str)
except (TypeError, ValueError):
raw = repr(message)
return hashlib.sha1(raw.encode('utf-8', errors='ignore')).hexdigest()
def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None:
if not msg_id or msg_id in state.seen_message_ids:
return
state.seen_message_ids.add(msg_id)
state.seen_message_order.append(msg_id)
while len(state.seen_message_order) > _MAX_VALUES_HISTORY:
dropped = state.seen_message_order.popleft()
state.seen_message_ids.discard(dropped)
def _extract_new_messages_from_values(
self,
values_messages: list[typing.Any],
state: _StreamState,
) -> list[dict[str, typing.Any]]:
new_messages: list[dict[str, typing.Any]] = []
no_id_indexes_seen: set[int] = set()
for idx, msg in enumerate(values_messages):
if not isinstance(msg, dict):
continue
msg_id = stream_utils.get_message_id(msg)
if msg_id:
if msg_id in state.seen_message_ids:
continue
self._remember_seen_message_id(state, msg_id)
new_messages.append(msg)
continue
no_id_indexes_seen.add(idx)
fp = self._fingerprint_message(msg)
if state.no_id_message_fingerprints.get(idx) == fp:
continue
state.no_id_message_fingerprints[idx] = fp
new_messages.append(msg)
for idx in list(state.no_id_message_fingerprints.keys()):
if idx not in no_id_indexes_seen:
state.no_id_message_fingerprints.pop(idx, None)
return new_messages
# ------------------------------------------------------------------
# 用户输入处理
# ------------------------------------------------------------------
def _build_user_content(
self,
prompt: str,
image_urls: list[str],
) -> typing.Any:
"""构建 LangGraph 兼容的 user content支持多模态"""
if not image_urls:
return prompt
content: list[dict[str, typing.Any]] = []
if prompt:
content.append({'type': 'text', 'text': prompt})
for url in image_urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url:
continue
if url.startswith(('http://', 'https://', 'data:')):
content.append({'type': 'image_url', 'image_url': {'url': url}})
return content if content else prompt
def _preprocess_user_message(
self,
query: pipeline_query.Query,
) -> tuple[str, list[str]]:
"""提取用户消息的纯文本与图片 URL 列表"""
plain_text = ''
image_urls: list[str] = []
if isinstance(query.user_message.content, str):
plain_text = query.user_message.content
elif isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
elif ce.type == 'image_base64':
# 转换为 data URI 形式
b64 = getattr(ce, 'image_base64', '')
if b64:
if not b64.startswith('data:'):
b64 = f'data:image/png;base64,{b64}'
image_urls.append(b64)
elif ce.type == 'image_url':
url = getattr(ce, 'image_url', '')
if url:
image_urls.append(url)
return plain_text, image_urls
# ------------------------------------------------------------------
# 请求构造
# ------------------------------------------------------------------
def _build_messages(
self,
prompt: str,
image_urls: list[str],
system_prompt: str = '',
) -> list[dict[str, typing.Any]]:
messages: list[dict[str, typing.Any]] = []
if system_prompt:
messages.append({'role': 'system', 'content': system_prompt})
messages.append(
{
'role': 'user',
'content': self._build_user_content(prompt, image_urls),
}
)
return messages
def _build_runtime_configurable(self, thread_id: str) -> dict[str, typing.Any]:
cfg: dict[str, typing.Any] = {
'thread_id': thread_id,
'thinking_enabled': self.thinking_enabled,
'is_plan_mode': self.plan_mode,
'subagent_enabled': self.subagent_enabled,
}
if self.subagent_enabled:
cfg['max_concurrent_subagents'] = self.max_concurrent_subagents
if self.model_name:
cfg['model_name'] = self.model_name
return cfg
def _build_payload(
self,
thread_id: str,
prompt: str,
image_urls: list[str],
system_prompt: str = '',
) -> dict[str, typing.Any]:
runtime_configurable = self._build_runtime_configurable(thread_id)
return {
'assistant_id': self.assistant_id,
'input': {
'messages': self._build_messages(prompt, image_urls, system_prompt),
},
'stream_mode': ['values', 'messages-tuple', 'custom'],
# DeerFlow 2.0 从 config.configurable 读取运行时覆盖
# 同时保留 context 字段做向后兼容
'context': dict(runtime_configurable),
'config': {
'recursion_limit': self.recursion_limit,
'configurable': runtime_configurable,
},
}
# ------------------------------------------------------------------
# Session/Thread 管理
# ------------------------------------------------------------------
async def _ensure_thread_id(self, query: pipeline_query.Query) -> str:
"""从 query.session 取/创建 deerflow thread_id
LangBot 使用 `query.session.using_conversation.uuid` 持久化 conversation id
我们复用这个字段存储 deerflow thread_id与 Dify Runner 同样做法)。
"""
thread_id = query.session.using_conversation.uuid or ''
if thread_id:
return thread_id
thread = await self.deerflow_client.create_thread(timeout=min(30, self.timeout))
thread_id = thread.get('thread_id', '')
if not thread_id:
raise errors.DeerFlowAPIError(message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}')
query.session.using_conversation.uuid = thread_id
return thread_id
# ------------------------------------------------------------------
# 流式事件处理
# ------------------------------------------------------------------
def _handle_values_event(
self,
data: typing.Any,
state: _StreamState,
) -> str | None:
"""处理 values 事件,返回新的完整文本(增量基础上的全量)"""
values_messages = stream_utils.extract_messages_from_values_data(data)
if not values_messages:
return None
new_messages: list[dict[str, typing.Any]] = []
if not state.baseline_initialized:
state.baseline_initialized = True
for idx, msg in enumerate(values_messages):
if not isinstance(msg, dict):
continue
new_messages.append(msg)
msg_id = stream_utils.get_message_id(msg)
if msg_id:
self._remember_seen_message_id(state, msg_id)
continue
state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg)
else:
new_messages = self._extract_new_messages_from_values(values_messages, state)
latest_text = ''
if new_messages:
state.run_values_messages.extend(new_messages)
if len(state.run_values_messages) > _MAX_VALUES_HISTORY:
state.run_values_messages = state.run_values_messages[-_MAX_VALUES_HISTORY:]
latest_text = stream_utils.extract_latest_ai_text(state.run_values_messages)
if latest_text:
state.has_values_text = True
latest_clarification = stream_utils.extract_latest_clarification_text(
state.run_values_messages,
)
if latest_clarification:
state.clarification_text = latest_clarification
return latest_text or None
def _handle_message_event(
self,
data: typing.Any,
state: _StreamState,
) -> str | None:
"""处理 messages-tuple 事件,返回增量文本
当 values 事件已经提供完整文本时,跳过 messages-tuple 的增量
"""
delta = stream_utils.extract_ai_delta_from_event_data(data)
if delta and not state.has_values_text:
state.latest_text += delta
return delta
maybe_clar = stream_utils.extract_clarification_from_event_data(data)
if maybe_clar:
state.clarification_text = maybe_clar
return None
def _build_final_text(self, state: _StreamState) -> str:
"""构建最终输出文本"""
if state.clarification_text:
return state.clarification_text
# 优先使用最后一条 AI message 的文本
latest_ai = stream_utils.extract_latest_ai_message(state.run_values_messages)
if latest_ai:
text = stream_utils.extract_text(latest_ai.get('content'))
if text:
if state.timed_out:
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
if state.latest_text:
text = state.latest_text
if state.timed_out:
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
# 提取任务失败信息作兜底
failure_text = stream_utils.build_task_failure_summary(state.task_failures)
if failure_text:
return failure_text
return 'DeerFlow 返回空响应'
# ------------------------------------------------------------------
# 主流程
# ------------------------------------------------------------------
async def _stream_messages_chunk(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""流式输出生成器"""
plain_text, image_urls = self._preprocess_user_message(query)
system_prompt = ''
# LangBot 的 pipeline 通常通过 prompt-preprocess 已注入 system prompt
# 这里保持空,让 prompt-preprocess 的内容作为 user message 一并送给 deerflow
thread_id = await self._ensure_thread_id(query)
payload = self._build_payload(
thread_id=thread_id,
prompt=plain_text or 'continue',
image_urls=image_urls,
system_prompt=system_prompt,
)
state = _StreamState()
prev_text = ''
message_idx = 0
try:
async for event in self.deerflow_client.stream_run(
thread_id=thread_id,
payload=payload,
timeout=self.timeout,
):
event_type = event.get('event')
data = event.get('data')
if event_type == 'values':
new_full = self._handle_values_event(data, state)
if new_full and new_full != prev_text:
delta = new_full[len(prev_text) :] if new_full.startswith(prev_text) else new_full
prev_text = new_full
if delta:
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
content=new_full,
is_final=False,
)
continue
if event_type in {'messages-tuple', 'messages', 'message'}:
delta = self._handle_message_event(data, state)
if delta:
prev_text = state.latest_text
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
content=prev_text,
is_final=False,
)
continue
if event_type == 'custom':
state.task_failures.extend(
stream_utils.extract_task_failures_from_custom_event(data),
)
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
# 最终消息
final_text = self._build_final_text(state)
yield provider_message.MessageChunk(
role='assistant',
content=final_text,
is_final=True,
)
async def _messages(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""非流式聚合输出"""
plain_text, image_urls = self._preprocess_user_message(query)
thread_id = await self._ensure_thread_id(query)
payload = self._build_payload(
thread_id=thread_id,
prompt=plain_text or 'continue',
image_urls=image_urls,
)
state = _StreamState()
try:
async for event in self.deerflow_client.stream_run(
thread_id=thread_id,
payload=payload,
timeout=self.timeout,
):
event_type = event.get('event')
data = event.get('data')
if event_type == 'values':
self._handle_values_event(data, state)
continue
if event_type in {'messages-tuple', 'messages', 'message'}:
self._handle_message_event(data, state)
continue
if event_type == 'custom':
state.task_failures.extend(
stream_utils.extract_task_failures_from_custom_event(data),
)
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
final_text = self._build_final_text(state)
yield provider_message.Message(
role='assistant',
content=final_text,
)
async def run(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""主入口:根据 adapter 是否支持流式输出,选择流式或非流式"""
if await query.adapter.is_stream_output_supported():
msg_idx = 0
async for msg in self._stream_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
else:
async for msg in self._messages(query):
yield msg

View File

@@ -1,351 +0,0 @@
from __future__ import annotations
import typing
import json
from langbot.pkg.provider import runner
from langbot.pkg.core import app
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot.libs.weknora_api import client, errors
@runner.runner_class('weknora-api')
class WeKnoraAPIRunner(runner.RequestRunner):
"""WeKnora API 对话请求器"""
weknora_client: client.AsyncWeKnoraClient
def __init__(self, ap: app.Application, pipeline_config: dict):
super().__init__(ap, pipeline_config)
valid_app_types = ['chat', 'agent']
if self.pipeline_config['ai']['weknora-api']['app-type'] not in valid_app_types:
raise errors.WeKnoraAPIError(
f'不支持的 WeKnora 应用类型: {self.pipeline_config["ai"]["weknora-api"]["app-type"]}'
)
api_key = self.pipeline_config['ai']['weknora-api'].get('api-key', '').strip()
if not api_key:
raise errors.WeKnoraAPIError(
'WeKnora API Key 未配置,请在流水线的 WeKnora API 配置中填入 API Key '
'(从 WeKnora 前端 设置 → API Keys 生成)'
)
base_url = self.pipeline_config['ai']['weknora-api'].get('base-url', '').strip()
if not base_url:
raise errors.WeKnoraAPIError('WeKnora Base URL 未配置,请填入服务器地址,例如 http://localhost:8080/api/v1')
self.weknora_client = client.AsyncWeKnoraClient(
api_key=api_key,
base_url=base_url,
)
async def _extract_plain_text(self, query: pipeline_query.Query) -> str:
"""从用户消息中提取纯文本内容"""
plain_text = ''
if isinstance(query.user_message.content, str):
plain_text = query.user_message.content
elif isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
if not plain_text:
plain_text = self.pipeline_config['ai']['weknora-api'].get('base-prompt', '')
return plain_text
async def _ensure_session(self, query: pipeline_query.Query) -> str:
"""确保会话存在,如果不存在则创建"""
session_id = query.session.using_conversation.uuid or ''
if not session_id:
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
session_id = await self.weknora_client.create_session(title=f'IM Chat - {user_tag}')
query.session.using_conversation.uuid = session_id
return session_id
async def _agent_chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用 Agent 智能对话(非流式聚合输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
knowledge_base_ids = config.get('knowledge-base-ids', [])
web_search_enabled = config.get('web-search-enabled', False)
timeout = config.get('timeout', 120)
full_answer = ''
chunk = None
async for chunk in self.weknora_client.agent_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
web_search_enabled=web_search_enabled,
timeout=timeout,
):
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
if response_type == 'tool_call':
# 工具调用
tool_data = chunk.get('data', {})
tool_name = tool_data.get('tool_name', '')
if tool_name:
yield provider_message.Message(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id=chunk.get('id', ''),
type='function',
function=provider_message.FunctionCall(
name=tool_name,
arguments=json.dumps(tool_data.get('arguments', {})),
),
)
],
)
elif response_type == 'answer':
if content:
full_answer += content
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if full_answer:
yield provider_message.Message(
role='assistant',
content=full_answer,
)
async def _chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用知识库 RAG 问答(非流式聚合输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-quick-answer')
knowledge_base_ids = config.get('knowledge-base-ids', [])
timeout = config.get('timeout', 120)
full_answer = ''
chunk = None
async for chunk in self.weknora_client.knowledge_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
timeout=timeout,
):
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
if response_type == 'answer':
if content:
full_answer += content
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if full_answer:
yield provider_message.Message(
role='assistant',
content=full_answer,
)
async def _agent_chat_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用 Agent 智能对话(流式输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
knowledge_base_ids = config.get('knowledge-base-ids', [])
web_search_enabled = config.get('web-search-enabled', False)
timeout = config.get('timeout', 120)
pending_answer = ''
message_idx = 0
is_final = False
chunk = None
async for chunk in self.weknora_client.agent_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
web_search_enabled=web_search_enabled,
timeout=timeout,
):
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
done = chunk.get('done', False)
if response_type == 'tool_call':
tool_data = chunk.get('data', {})
tool_name = tool_data.get('tool_name', '')
if tool_name:
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id=chunk.get('id', ''),
type='function',
function=provider_message.FunctionCall(
name=tool_name,
arguments=json.dumps(tool_data.get('arguments', {})),
),
)
],
)
elif response_type == 'answer':
message_idx += 1
if content:
pending_answer += content
if done:
is_final = True
# 每 8 个 chunk 输出一次,或最终输出
if message_idx % 8 == 0 or is_final:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=is_final,
)
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
# 确保最终消息已发出
if not is_final and pending_answer:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=True,
)
async def _chat_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用知识库 RAG 问答(流式输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-quick-answer')
knowledge_base_ids = config.get('knowledge-base-ids', [])
timeout = config.get('timeout', 120)
pending_answer = ''
message_idx = 0
is_final = False
chunk = None
async for chunk in self.weknora_client.knowledge_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
timeout=timeout,
):
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
done = chunk.get('done', False)
if response_type == 'answer':
message_idx += 1
if content:
pending_answer += content
if done:
is_final = True
if message_idx % 8 == 0 or is_final:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=is_final,
)
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if not is_final and pending_answer:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=True,
)
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
app_type = self.pipeline_config['ai']['weknora-api']['app-type']
if await query.adapter.is_stream_output_supported():
msg_idx = 0
if app_type == 'agent':
async for msg in self._agent_chat_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
elif app_type == 'chat':
async for msg in self._chat_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
else:
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')
else:
if app_type == 'agent':
async for msg in self._agent_chat_messages(query):
yield msg
elif app_type == 'chat':
async for msg in self._chat_messages(query):
yield msg
else:
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')

View File

@@ -1038,6 +1038,8 @@ else:
run_id=run_session.get('run_id') if isinstance(run_session, dict) else None,
runner_id=run_session.get('runner_id') if isinstance(run_session, dict) else None,
bot_id=getattr(query, 'bot_uuid', None),
workspace_id=authorization.get('workspace_id'),
thread_id=authorization.get('thread_id'),
metadata=metadata,
)
artifact_ref = {

View File

@@ -43,6 +43,9 @@ def make_session(
plugin_identity: str = 'test/test-runner',
resources: dict | None = None,
conversation_id: str | None = None,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
available_apis: dict[str, bool] | None = None,
state_policy: dict[str, typing.Any] | None = None,
state_context: dict[str, typing.Any] | None = None,
@@ -114,6 +117,9 @@ def make_session(
'resources': res,
'available_apis': apis,
'conversation_id': conversation_id,
'bot_id': bot_id,
'workspace_id': workspace_id,
'thread_id': thread_id,
'state_policy': policy,
'state_context': context,
'authorized_ids': authorized_ids,

View File

@@ -208,10 +208,20 @@ class TestArtifactAuthorization:
class TestArtifactAccessValidation:
"""Test _validate_artifact_access authorization rules."""
def _make_session(self, conversation_id: str | None):
def _make_session(
self,
conversation_id: str | None,
*,
bot_id: str | None = None,
workspace_id: str | None = None,
thread_id: str | None = None,
):
return make_session(
run_id="run_001",
conversation_id=conversation_id,
bot_id=bot_id,
workspace_id=workspace_id,
thread_id=thread_id,
available_apis={"artifact_metadata": True, "artifact_read": True},
)
@@ -259,6 +269,64 @@ class TestArtifactAccessValidation:
assert is_allowed is True
assert error is None
def test_same_conversation_and_scope_allowed(self):
"""Artifacts in the same run scope are allowed across runs."""
session = self._make_session(
"conv_001",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_001",
)
metadata = {
"artifact_id": "art_001",
"conversation_id": "conv_001",
"run_id": "run_other",
"bot_id": "bot_001",
"workspace_id": "workspace_001",
"thread_id": "thread_001",
}
is_allowed, error = self._call_validate(session, metadata)
assert is_allowed is True
assert error is None
def test_same_conversation_different_scope_denied(self):
"""Artifacts in another bot/thread scope are denied even in the same conversation."""
session = self._make_session(
"conv_001",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_001",
)
metadata = {
"artifact_id": "art_001",
"conversation_id": "conv_001",
"run_id": "run_other",
"bot_id": "bot_002",
"workspace_id": "workspace_001",
"thread_id": "thread_001",
}
is_allowed, error = self._call_validate(session, metadata)
assert is_allowed is False
assert "denied" in error.lower()
def test_same_conversation_missing_scope_denied_for_scoped_session(self):
"""Scoped runs should not read legacy-scope artifacts from other runs."""
session = self._make_session("conv_001", bot_id="bot_001")
metadata = {
"artifact_id": "art_001",
"conversation_id": "conv_001",
"run_id": "run_other",
"bot_id": None,
"workspace_id": None,
"thread_id": None,
}
is_allowed, error = self._call_validate(session, metadata)
assert is_allowed is False
assert "denied" in error.lower()
def test_different_conversation_and_run_denied(self):
"""Artifacts in different conversation and different run are denied."""
session = self._make_session("conv_001")
@@ -470,6 +538,9 @@ class TestArtifactStoreRealSQLite:
content=content,
conversation_id="conv_001",
run_id="run_001",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_001",
)
assert artifact_id == "art_real_001"
@@ -489,6 +560,14 @@ class TestArtifactStoreRealSQLite:
assert "storage_type" not in metadata
assert "bot_id" not in metadata
assert "workspace_id" not in metadata
assert "thread_id" not in metadata
assert "_langbot_thread_id" not in metadata.get("metadata", {})
auth_metadata = await store.get_authorization_metadata(artifact_id)
assert auth_metadata is not None
assert auth_metadata["bot_id"] == "bot_001"
assert auth_metadata["workspace_id"] == "workspace_001"
assert auth_metadata["thread_id"] == "thread_001"
@pytest.mark.asyncio
async def test_read_artifact_round_trip(self, db_engine):

View File

@@ -628,6 +628,52 @@ class TestTranscriptStoreRealSQLite:
assert messages[0].content[0].text == "User structured text"
assert messages[1].content == "Assistant text"
@pytest.mark.asyncio
async def test_get_legacy_provider_messages_filters_scope(self, db_engine):
"""Legacy Pipeline history projection must stay inside the current run scope."""
store = TranscriptStore(db_engine)
await store.append_transcript(
transcript_id="trans_scope_001",
event_id="evt_scope_001",
conversation_id="conv_scope",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_001",
role="user",
content="Current scope text",
)
await store.append_transcript(
transcript_id="trans_scope_002",
event_id="evt_scope_002",
conversation_id="conv_scope",
bot_id="bot_002",
workspace_id="workspace_001",
thread_id="thread_001",
role="assistant",
content="Other bot text",
)
await store.append_transcript(
transcript_id="trans_scope_003",
event_id="evt_scope_003",
conversation_id="conv_scope",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_002",
role="assistant",
content="Other thread text",
)
messages = await store.get_legacy_provider_messages(
"conv_scope",
bot_id="bot_001",
workspace_id="workspace_001",
thread_id="thread_001",
strict_thread=True,
)
assert [message.content for message in messages] == ["Current scope text"]
@pytest.mark.asyncio
async def test_search_transcript_real_db(self, db_engine):
"""Test search_transcript with real DB."""

View File

@@ -193,6 +193,41 @@ async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
]
@pytest.mark.asyncio
async def test_build_resources_accepts_dynamic_form_type_aliases(app):
"""Frontend DynamicForm aliases should resolve to runtime resource grants."""
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
async def get_kb(kb_uuid):
return SimpleNamespace(
uuid=kb_uuid,
get_name=lambda: f'name-{kb_uuid}',
knowledge_base_entity=SimpleNamespace(kb_type='default'),
)
app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb)
descriptor = make_descriptor(
capabilities={'knowledge_retrieval': True},
config_schema=[
{'name': 'model', 'type': 'select-llm-model'},
{'name': 'knowledge-bases', 'type': 'select-knowledge-bases'},
],
)
query = make_query({
'model': 'llm_alias',
'knowledge-bases': ['kb_alias'],
})
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
]
assert resources['knowledge_bases'] == [
{'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
]
@pytest.mark.asyncio
async def test_build_models_manifest_permission_narrows_binding(app):
"""Manifest model permissions narrower than binding should remove LLM grants."""

View File

@@ -288,6 +288,45 @@ class TestSessionRegistryBasic:
assert len(items) == MAX_STEERING_QUEUE_ITEMS
assert all(item['event']['event_id'] != 'overflow' for item in items)
@pytest.mark.asyncio
async def test_find_steering_target_requires_same_scope(self):
"""Steering claims must not cross bot/workspace/thread boundaries."""
registry = AgentRunSessionRegistry()
await registry.register(
run_id='run_steering_scoped',
runner_id='plugin:test/my-runner/default',
query_id=1,
plugin_identity='test/my-runner',
resources=make_resources(),
conversation_id='conv_1',
bot_id='bot_1',
workspace_id='workspace_1',
thread_id='thread_1',
available_apis={'steering_pull': True},
)
assert await registry.find_steering_target(
conversation_id='conv_1',
runner_id='plugin:test/my-runner/default',
bot_id='bot_1',
workspace_id='workspace_1',
thread_id='thread_1',
) == 'run_steering_scoped'
assert await registry.find_steering_target(
conversation_id='conv_1',
runner_id='plugin:test/my-runner/default',
bot_id='bot_2',
workspace_id='workspace_1',
thread_id='thread_1',
) is None
assert await registry.find_steering_target(
conversation_id='conv_1',
runner_id='plugin:test/my-runner/default',
bot_id='bot_1',
workspace_id='workspace_1',
thread_id='thread_2',
) is None
@pytest.mark.asyncio
async def test_unregister_returns_pending_steering_queue(self):
"""Unregister returns the removed session so callers can audit pending steering."""

View File

@@ -630,7 +630,9 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
mock_session = MagicMock()
mock_session.server_name = 'transient-test-server'
mock_session.start = AsyncMock()
mock_session.shutdown = AsyncMock()
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
@@ -645,4 +647,9 @@ class TestMCPServiceTestMCPServer:
# Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456
assert task_id == 456
coroutine = ap.task_mgr.create_user_task.call_args.args[0]
await coroutine
mock_session.start.assert_awaited_once()
mock_session.shutdown.assert_awaited_once()

View File

@@ -181,6 +181,23 @@ def make_app(
)
def test_resolve_box_session_id_reads_current_runner_config():
query = make_query(101)
query.pipeline_config = {
'ai': {
'runner': {'id': 'plugin:langbot/local-agent/default'},
'runner_config': {
'plugin:langbot/local-agent/default': {
'box-session-id-template': 'bot-{launcher_id}-{sender_id}',
},
},
},
}
service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient))
assert service.resolve_box_session_id(query) == 'bot-test_user-test_user'
@pytest.mark.asyncio
async def test_box_service_without_explicit_client_initializes_internal_connector(monkeypatch: pytest.MonkeyPatch):
connector = Mock()

View File

@@ -273,6 +273,13 @@ async def test_preproc_uses_transcript_history_view_when_available():
assert result.result_type == entities_module.ResultType.CONTINUE
assert query.messages == transcript_messages
stage._load_agent_runner_history_messages.assert_awaited_once_with(
'plugin:langbot/local-agent/default',
'conv-1',
bot_id='bot-1',
workspace_id=None,
thread_id=None,
)
@pytest.mark.asyncio