fix(agent-runner): harden run lifecycle and protocol stores

This commit is contained in:
huanghuoguoguo
2026-06-13 21:22:13 +08:00
parent 735a0011b0
commit 1153433693
16 changed files with 450 additions and 103 deletions

View File

@@ -324,6 +324,7 @@ class InlineContextPolicy(BaseModel):
reason: str | None = None
class ContextAPICapabilities(BaseModel):
prompt_get: bool = False
history_page: bool = False
history_search: bool = False
event_get: bool = False
@@ -473,8 +474,8 @@ Host 必须校验 `state.updated` 的 scope、key、value 大小和 JSON 可序
```python
# Model
await api.invoke_llm(model_id, messages, funcs=None, extra_args=None)
async for chunk in api.invoke_llm_stream(model_id, messages, funcs=None, extra_args=None):
await api.invoke_llm(llm_model_uuid, messages, funcs=None, extra_args=None)
async for chunk in api.invoke_llm_stream(llm_model_uuid, messages, funcs=None, extra_args=None):
...
await api.invoke_rerank(rerank_model_id, query, documents, top_k=None)
@@ -486,13 +487,14 @@ await api.call_tool(tool_name, parameters)
await api.retrieve_knowledge(kb_id, query_text, top_k=5, filters=None)
# History（返回 Transcript projection，不返回原始平台 payload）
await api.get_prompt()
await api.history_page(conversation_id=None, before_cursor=None, after_cursor=None,
limit=50, direction="backward", include_artifacts=False)
await api.history_search(query, filters=None, top_k=10)
# Event（返回稳定 event envelope 或受限 raw ref，不默认返回大 payload）
await api.event_get(event_id)
await api.event_page(before_cursor=None, limit=50)
await api.event_page(conversation_id=None, event_types=None, before_cursor=None, limit=50)
await api.steering_pull(mode="all", limit=None)
# Artifact（必须支持大小限制、MIME 校验、过期时间和授权范围）
@@ -502,7 +504,7 @@ await api.artifact_read_range(artifact_id, offset=0, length=65536)
# State / Storage
await api.state_get(scope, key); await api.state_set(scope, key, value); await api.state_delete(scope, key)
await api.state_list(scope, prefix=None)
await api.state_list(scope, prefix=None, limit=100)
await api.get_plugin_storage(key); await api.set_plugin_storage(key, value); await api.delete_plugin_storage(key)
await api.get_plugin_storage_keys()
await api.get_workspace_storage(key); await api.set_workspace_storage(key, value); await api.delete_workspace_storage(key)
@@ -513,6 +515,15 @@ await api.get_file(file_key)
await api.get_langbot_version()
```
`invoke_llm()` / `invoke_llm_stream()` 的第一个参数在 SDK 中命名为
`llm_model_uuid`，wire payload 字段也是 `llm_model_uuid`。该值对 runner
仍是 opaque identifier，不应解析其内部格式。
`get_prompt()` 返回当前 query-backed run 的 Host effective prompt messages，即
`list[Message]` 的 JSON 形式。该能力只在 `ctx.context.available_apis.prompt_get`
为 true 时可用；没有 query 缓存、prompt 已过期或非 query entry run 时，Host
可以返回错误或空列表。Runner 应在不可用时回退到自己的 config/prompt 策略。
`steering_pull(mode="all")` 是推荐默认：Host 按 claim 顺序返回全部 pending steering 输入并清空对应队列。`mode="one-at-a-time"` 仅用于 runner 主动节流，每次返回一条。Host 不合并多条用户消息，runner 负责在 turn 边界决定模型侧格式。
Steering 审计使用 EventLog 而不是 Transcript schema 扩展：被 active run 吸收的原始 `message.received` 事件保留原事件类型，并在 `metadata.steering` 标记 `status="queued"`、`trigger_behavior="absorbed_into_active_run"`、`claimed_by_run_id`、`claimed_runner_id`、`claimed_at`。Runner 成功 pull 后，Host 追加 `steering.injected` EventLog 记录，`metadata.steering.status="injected"` 并引用 `source_event_id`。若 run 结束时仍有已 claim 但未 pull 的 steering 输入，Host 追加 `steering.dropped` EventLog 记录，`metadata.steering.status="dropped"` 并引用 `source_event_id`；这不是用户消息事实的删除，只是 dispatch 终态。Transcript 继续只表示会话事实，不承担 dispatch 行为标记。

View File

@@ -8,6 +8,7 @@ import uuid
import base64
import os
import aiofiles
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker
@@ -17,6 +18,23 @@ from ...entity.persistence.bstorage import BinaryStorage
_FILE_ARTIFACT_METADATA_KEY = '_langbot_file_artifact'
_ARTIFACT_THREAD_METADATA_KEY = '_langbot_thread_id'
UTC = datetime.timezone.utc
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(UTC)
def _as_utc(value: datetime.datetime) -> datetime.datetime:
if value.tzinfo is None:
return value.replace(tzinfo=UTC)
return value.astimezone(UTC)
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
if value is None:
return None
return int(_as_utc(value).timestamp())
class ArtifactStore:
@@ -191,7 +209,7 @@ class ArtifactStore:
runner_id=runner_id,
bot_id=bot_id,
workspace_id=workspace_id,
created_at=datetime.datetime.utcnow(),
created_at=_utc_now(),
expires_at=expires_at,
metadata_json=json.dumps(metadata_payload) if metadata_payload else None,
)
@@ -336,7 +354,7 @@ class ArtifactStore:
}
if storage_type == 'file':
return self._read_file_storage(record, artifact_id, offset, limit)
return await self._read_file_storage(record, artifact_id, offset, limit)
# For other storage types, return storage reference
# (caller can use file_key for chunked transfer)
@@ -363,7 +381,7 @@ class ArtifactStore:
backing lifecycle remains owned by the storage provider.
"""
if now is None:
now = datetime.datetime.utcnow()
now = _utc_now()
async with self._session_factory() as session:
result = await session.execute(
@@ -416,7 +434,7 @@ class ArtifactStore:
return None
return row.value
def _read_file_storage(
async def _read_file_storage(
self,
record: AgentArtifact,
artifact_id: str,
@@ -441,9 +459,9 @@ class ArtifactStore:
if offset >= file_size:
content = b''
else:
with open(real_path, 'rb') as f:
f.seek(offset)
content = f.read(limit)
async with aiofiles.open(real_path, 'rb') as f:
await f.seek(offset)
content = await f.read(limit)
return {
'artifact_id': artifact_id,
@@ -491,8 +509,8 @@ class ArtifactStore:
if row.expires_at is None:
return False
if now is None:
now = datetime.datetime.utcnow()
return row.expires_at <= now
now = _utc_now()
return _as_utc(row.expires_at) <= _as_utc(now)
def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
"""Convert an AgentArtifact row to public dict.
@@ -511,7 +529,7 @@ class ArtifactStore:
'conversation_id': row.conversation_id,
'run_id': row.run_id,
'runner_id': row.runner_id,
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None,
'created_at': _datetime_to_epoch(row.created_at),
'expires_at': _datetime_to_epoch(row.expires_at),
'metadata': self._public_metadata(row.metadata_json),
}

View File

@@ -22,7 +22,14 @@ class AgentBindingResolver:
event: AgentEventEnvelope,
agents: list[AgentConfig],
) -> AgentBinding:
"""Resolve exactly one enabled Agent for the event."""
"""Resolve exactly one enabled Agent for the event.
Callers that source agents from bot/workspace/global configuration must
pre-filter candidates to the event scope before calling this resolver.
The current AgentConfig model represents one already-selected product
Agent and does not carry enough scope metadata to make that decision
safely here.
"""
matches = [
agent
for agent in agents

View File

@@ -13,6 +13,23 @@ from sqlalchemy.orm import sessionmaker
from ...entity.persistence.event_log import EventLog
UTC = datetime.timezone.utc
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(UTC)
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
else:
value = value.astimezone(UTC)
return int(value.timestamp())
class EventLogStore:
"""Store for EventLog records.
@@ -107,7 +124,7 @@ class EventLogStore:
run_id=run_id,
runner_id=runner_id,
metadata_json=json.dumps(metadata) if metadata else None,
created_at=datetime.datetime.utcnow(),
created_at=_utc_now(),
)
session.add(event)
await session.commit()
@@ -277,7 +294,7 @@ class EventLogStore:
'id': row.id,
'event_id': row.event_id,
'event_type': row.event_type,
'event_time': int(row.event_time.timestamp()) if row.event_time else None,
'event_time': _datetime_to_epoch(row.event_time),
'source': row.source,
'bot_id': row.bot_id,
'workspace_id': row.workspace_id,
@@ -293,6 +310,6 @@ class EventLogStore:
'raw_ref': row.raw_ref,
'run_id': row.run_id,
'runner_id': row.runner_id,
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
'created_at': _datetime_to_epoch(row.created_at),
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
}

View File

@@ -103,40 +103,6 @@ class AgentRunOrchestrator:
state_context = build_state_context(event, binding, descriptor)
run_id = context['run_id']
await self._session_registry.register(
run_id=run_id,
runner_id=descriptor.id,
query_id=session_query_id,
plugin_identity=descriptor.get_plugin_id(),
resources=resources,
available_apis=context.get('context', {}).get('available_apis'),
conversation_id=event.conversation_id,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
thread_id=event.thread_id,
state_policy={
'enable_state': binding.state_policy.enable_state,
'state_scopes': list(binding.state_policy.state_scopes),
},
state_context=state_context,
)
event_log_id = await self.journal.write_event_log(
event=event,
binding=binding,
run_id=run_id,
runner_id=descriptor.id,
)
await self.journal.register_input_artifacts(
event=event,
run_id=run_id,
runner_id=descriptor.id,
)
if event.event_type == 'message.received' and event.conversation_id:
await self.journal.write_user_transcript(
event=event,
event_log_id=event_log_id,
)
pending_artifact_refs: list[dict[str, typing.Any]] = []
seen_sequences: set[int] = set()
@@ -144,6 +110,41 @@ class AgentRunOrchestrator:
assistant_transcript_written = False
try:
await self._session_registry.register(
run_id=run_id,
runner_id=descriptor.id,
query_id=session_query_id,
plugin_identity=descriptor.get_plugin_id(),
resources=resources,
available_apis=context.get('context', {}).get('available_apis'),
conversation_id=event.conversation_id,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
thread_id=event.thread_id,
state_policy={
'enable_state': binding.state_policy.enable_state,
'state_scopes': list(binding.state_policy.state_scopes),
},
state_context=state_context,
)
event_log_id = await self.journal.write_event_log(
event=event,
binding=binding,
run_id=run_id,
runner_id=descriptor.id,
)
await self.journal.register_input_artifacts(
event=event,
run_id=run_id,
runner_id=descriptor.id,
)
if event.event_type == 'message.received' and event.conversation_id:
await self.journal.write_user_transcript(
event=event,
event_log_id=event_log_id,
)
async for result_dict in self.invoker.invoke(descriptor, context):
sequence = result_dict.get('sequence')
if sequence is not None:

View File

@@ -27,6 +27,7 @@ from .host_models import (
StatePolicy,
DeliveryPolicy,
)
from .config_migration import ConfigMigration
from . import events as runner_events
@@ -42,6 +43,7 @@ class QueryEntryAdapter:
INTERNAL_PREFIX = '_'
SENSITIVE_PATTERNS = ('secret', 'token', 'key', 'password', 'credential', 'api_key', 'apikey')
PERMISSION_VARS = ('_pipeline_bound_plugins', '_authorized', '_permission')
EVENT_DATA_MAX_STRING_BYTES = 512
@classmethod
def query_to_event(
@@ -103,8 +105,7 @@ class QueryEntryAdapter:
) -> AgentConfig:
"""Project the current Pipeline config container into target Agent config."""
pipeline_config = query.pipeline_config or {}
ai_config = pipeline_config.get('ai', {})
runner_config = ai_config.get('runner_config', {}).get(runner_id, {})
runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
agent_id = getattr(query, 'pipeline_uuid', None)
# Build resource policy from current config
@@ -199,12 +200,13 @@ class QueryEntryAdapter:
event_data: dict[str, typing.Any] = {}
if message_event and hasattr(message_event, 'model_dump'):
try:
event_data = message_event.model_dump(mode='json')
raw_event_data = message_event.model_dump(mode='json')
except TypeError:
event_data = message_event.model_dump()
raw_event_data = message_event.model_dump()
except Exception:
event_data = {}
event_data.pop('source_platform_object', None)
raw_event_data = {}
if isinstance(raw_event_data, dict):
event_data = cls._compact_event_data(raw_event_data)
source_event_type = None
if message_event:
@@ -231,6 +233,25 @@ class QueryEntryAdapter:
data=event_data,
)
@classmethod
def _compact_event_data(
cls,
event_data: dict[str, typing.Any],
) -> dict[str, typing.Any]:
"""Keep only small scalar source-event metadata in event.data."""
compact: dict[str, typing.Any] = {}
for key, value in event_data.items():
if key == 'source_platform_object' or key.startswith('_'):
continue
if value is None or isinstance(value, (bool, int, float)):
compact[key] = value
continue
if isinstance(value, str):
if len(value.encode('utf-8')) <= cls.EVENT_DATA_MAX_STRING_BYTES:
compact[key] = value
continue
return compact
@classmethod
def _build_scoped_event_id(
cls,
@@ -430,6 +451,18 @@ class QueryEntryAdapter:
import uuid
attachments: list[dict[str, typing.Any]] = []
seen_keys: dict[tuple[str, str, str], set[str]] = {}
def add_attachment(attachment: dict[str, typing.Any]) -> None:
key = cls._attachment_dedupe_key(attachment)
if key is not None:
source = str(attachment.get('source') or '')
sources = seen_keys.setdefault(key, set())
if source and sources and source not in sources:
return
if source:
sources.add(source)
attachments.append(attachment)
for elem in contents:
elem_type = elem.get('type')
@@ -437,21 +470,21 @@ class QueryEntryAdapter:
if elem_type == 'image_url':
image_url = elem.get('image_url') or {}
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'image',
'source': 'url',
'url': image_url.get('url') if isinstance(image_url, dict) else str(image_url),
})
elif elem_type == 'image_base64':
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'image',
'source': 'base64',
'content': elem.get('image_base64'),
})
elif elem_type == 'file_url':
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'file',
'source': 'url',
@@ -459,7 +492,7 @@ class QueryEntryAdapter:
'name': elem.get('file_name'),
})
elif elem_type == 'file_base64':
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'file',
'source': 'base64',
@@ -478,32 +511,56 @@ class QueryEntryAdapter:
artifact_id = str(uuid.uuid4()) # Generate unique ID
if isinstance(component, platform_message.Image):
attachments.append({
image_id = component.image_id or None
image_url = component.url or None
image_base64 = component.base64 or None
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'image',
'source': 'message_chain',
'id': component.image_id or None,
'url': component.url or None,
'id': image_id,
'url': image_url,
'content': image_base64,
})
elif isinstance(component, platform_message.File):
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'file',
'source': 'message_chain',
'id': component.id or None,
'name': component.name or None,
'url': component.url or None,
'content': component.base64 or None,
})
elif isinstance(component, platform_message.Voice):
attachments.append({
add_attachment({
'artifact_id': artifact_id,
'artifact_type': 'voice',
'source': 'message_chain',
'id': component.voice_id or None,
'url': component.url or None,
'content': component.base64 or None,
})
return attachments
@classmethod
def _attachment_dedupe_key(
cls,
attachment: dict[str, typing.Any],
) -> tuple[str, str, str] | None:
"""Return a stable key for the same attachment across content sources."""
artifact_type = attachment.get('artifact_type')
if not artifact_type:
return None
for field in ('id', 'url', 'content'):
value = attachment.get(field)
if value:
if field == 'content':
value = hashlib.sha256(str(value).encode('utf-8')).hexdigest()
return str(artifact_type), field, str(value)
return None
@classmethod
def _build_delivery_context(
cls,

View File

@@ -165,7 +165,11 @@ class AgentRunJournal:
input_json=input_json,
run_id=run_id,
runner_id=runner_id,
event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None,
event_time=(
datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc)
if event.event_time
else None
),
metadata=metadata,
)
@@ -468,7 +472,10 @@ class AgentRunJournal:
raw_event_time = event.get('event_time')
if raw_event_time:
try:
event_time = datetime.datetime.fromtimestamp(raw_event_time)
event_time = datetime.datetime.fromtimestamp(
raw_event_time,
datetime.timezone.utc,
)
except (TypeError, ValueError, OSError):
event_time = None

View File

@@ -122,6 +122,9 @@ class AgentRunSessionRegistry:
state_policy: State policy from binding (enable_state, state_scopes)
state_context: Context for state API (scope_keys, binding_identity, etc.)
"""
if not isinstance(plugin_identity, str) or not plugin_identity.strip():
raise ValueError('plugin_identity is required for agent run sessions')
now = int(time.time())
available_apis = copy.deepcopy(available_apis or {})

View File

@@ -14,6 +14,23 @@ from ...entity.persistence.transcript import Transcript
from langbot_plugin.api.entities.builtin.provider import message as provider_message
UTC = datetime.timezone.utc
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(UTC)
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
else:
value = value.astimezone(UTC)
return int(value.timestamp())
class TranscriptStore:
"""Store for Transcript records.
@@ -94,7 +111,7 @@ class TranscriptStore:
seq=0,
run_id=run_id,
runner_id=runner_id,
created_at=datetime.datetime.utcnow(),
created_at=_utc_now(),
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(item)
@@ -371,7 +388,7 @@ class TranscriptStore:
'content_json': json.loads(row.content_json) if row.content_json else None,
'seq': row.seq,
'cursor': str(row.seq),
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
'created_at': _datetime_to_epoch(row.created_at),
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
}

View File

@@ -55,7 +55,7 @@ class Transcript(Base):
"""Artifact references as JSON string (list of ArtifactRef)."""
# Sequence for cursor-based pagination
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True)
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
"""Monotonic cursor sequence for pagination."""
# Context

View File

@@ -29,6 +29,26 @@ def _load_config(config_value):
return None
def _update_config(conn, table_name: str, pipeline_uuid: str, migrated_config: dict) -> None:
    """Persist ``migrated_config`` as JSON on the row whose uuid matches.

    On PostgreSQL the bound string is cast to JSON in the statement; every
    other dialect receives the serialized string directly.
    """
    params = {'config': json.dumps(migrated_config), 'uuid': pipeline_uuid}

    if conn.dialect.name == 'postgresql':
        statement = sa.text(
            f'UPDATE {table_name} '
            'SET config = CAST(:config AS JSON) '
            'WHERE uuid = :uuid'
        )
    else:
        statement = sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid')

    conn.execute(statement, params)
def upgrade() -> None:
"""Normalize existing pipeline config containers."""
conn = op.get_bind()
@@ -56,10 +76,7 @@ def upgrade() -> None:
# Only update if config changed
if json.dumps(config, sort_keys=True) != json.dumps(migrated_config, sort_keys=True):
conn.execute(
sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid'),
{'config': json.dumps(migrated_config), 'uuid': pipeline_uuid},
)
_update_config(conn, table_name, pipeline_uuid, migrated_config)
except Exception:
# Skip invalid configs
continue

View File

@@ -211,19 +211,23 @@ async def _validate_agent_run_session(
)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return None, handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity')
return None, handler.ActionResponse.error(
message=f'Run session {run_id} has no plugin_identity'
)
if not caller_plugin_identity:
return None, handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
if api_capability:
available_apis = _get_run_authorization(session).get('available_apis', {})
@@ -384,19 +388,25 @@ async def _validate_run_authorization(
)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return None, handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}',
)
if caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
)
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
ap.logger.warning(
f'{resource_type.upper()}: run_id {run_id} has no plugin_identity'
)
return None, handler.ActionResponse.error(
message=f'Run session {run_id} has no plugin_identity',
)
if not caller_plugin_identity:
return None, handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}',
)
if caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
)
if not session_registry.is_resource_allowed(session, resource_type, resource_id, operation):
ap.logger.warning(

View File

@@ -106,6 +106,28 @@ class TestQueryToEventEnvelope:
"message_id": "source-message-1",
}
def test_query_to_event_keeps_large_payloads_out_of_event_data(self, mock_query):
    """Oversized strings and nested values in the dumped event must not reach event.data."""
    oversized_image = "data:image/png;base64," + ("a" * 1024)
    dumped_payload = {
        "type": "platform.message.created",
        "message_id": "source-message-1",
        "message_chain": [{"type": "Image", "base64": oversized_image}],
        "raw_text": "x" * 1024,
        "source_platform_object": {"large": "payload"},
    }

    platform_event = Mock()
    platform_event.type = "platform.message.created"
    platform_event.time = 1700000000
    platform_event.sender = None
    platform_event.model_dump = Mock(return_value=dumped_payload)
    mock_query.message_event = platform_event

    envelope = QueryEntryAdapter.query_to_event(mock_query)

    # Only the two short scalar fields are expected to survive.
    expected_data = {
        "type": "platform.message.created",
        "message_id": "source-message-1",
    }
    assert envelope.data == expected_data
def test_query_to_event_handles_missing_message_chain(self, mock_query):
"""Test delivery context building when Query has no message_chain."""
delattr(mock_query, "message_chain")
@@ -137,6 +159,29 @@ class TestQueryConfigToAgentConfig:
assert agent_config.runner_id == "plugin:author/plugin/runner"
def test_config_to_agent_config_uses_legacy_runner_config_migration(self, mock_query):
    """A legacy ``ai.local-agent`` section must be projected into the new runner config shape."""
    legacy_runner_section = {
        "model": "model-primary",
        "knowledge-base": "kb-001",
    }
    mock_query.pipeline_config = {
        "ai": {
            "runner": {"runner": "local-agent"},
            "local-agent": legacy_runner_section,
        }
    }

    projected = QueryEntryAdapter.config_to_agent_config(
        mock_query,
        "plugin:langbot/local-agent/default",
    )

    expected_model = {"primary": "model-primary", "fallbacks": []}
    assert projected.runner_config["model"] == expected_model
    assert projected.runner_config["knowledge-bases"] == ["kb-001"]
def test_resolver_projects_agent_scope(self, mock_query):
"""Test binding scope projection through the resolver."""
event = QueryEntryAdapter.query_to_event(mock_query)

View File

@@ -2024,6 +2024,75 @@ class TestCallerPluginIdentityValidation:
await registry.unregister('run_no_caller_identity')
@pytest.mark.asyncio
async def test_session_missing_plugin_identity_denied(self):
    """Malformed legacy sessions without plugin_identity fail closed.

    The session is written straight into the registry's private map so that
    a blank identity can be simulated. Cleanup runs in ``finally`` because
    the registry comes from ``get_session_registry()`` and outlives this
    test; a failed assertion must not leave the session behind.
    """
    from langbot.pkg.agent.runner.session_registry import get_session_registry
    from langbot.pkg.plugin.handler import _validate_run_authorization

    run_id = 'run_missing_session_identity'
    registry = get_session_registry()
    malformed_session = make_session(
        run_id=run_id,
        runner_id='plugin:test/runner/default',
        plugin_identity='',
        resources=make_resources(models=[{'model_id': 'model_001'}]),
    )

    mock_ap = MagicMock()
    mock_ap.logger = MagicMock()

    async with registry._lock:
        registry._sessions[run_id] = malformed_session

    try:
        validated_session, error = await _validate_run_authorization(
            run_id,
            'model',
            'model_001',
            mock_ap,
            caller_plugin_identity='test/runner',
        )

        assert validated_session is None
        assert error is not None
        assert 'no plugin_identity' in error.message
    finally:
        await registry.unregister(run_id)
@pytest.mark.asyncio
async def test_pull_api_session_missing_plugin_identity_denied(self):
    """Pull API validation also fails closed for missing session identity.

    The session is written straight into the registry's private map so that
    a blank identity can be simulated. Cleanup runs in ``finally`` because
    the registry comes from ``get_session_registry()`` and outlives this
    test; a failed assertion must not leave the session behind.
    """
    from langbot.pkg.agent.runner.session_registry import get_session_registry
    from langbot.pkg.plugin.handler import _validate_agent_run_session

    run_id = 'run_missing_pull_identity'
    registry = get_session_registry()
    malformed_session = make_session(
        run_id=run_id,
        runner_id='plugin:test/runner/default',
        plugin_identity='',
        available_apis={'history_page': True},
    )

    mock_ap = MagicMock()
    mock_ap.logger = MagicMock()

    async with registry._lock:
        registry._sessions[run_id] = malformed_session

    try:
        validated_session, error = await _validate_agent_run_session(
            run_id,
            'test/runner',
            mock_ap,
            'HISTORY_PAGE',
            'history_page',
        )

        assert validated_session is None
        assert error is not None
        assert 'no plugin_identity' in error.message
    finally:
        await registry.unregister(run_id)
class TestBackwardCompatStorageNoRunId:
"""Tests for unscoped storage actions without run_id.

View File

@@ -275,6 +275,37 @@ def test_context_builder_includes_consumable_base64_attachments():
assert input_data.attachments[1].name == "hello.txt"
def test_context_builder_deduplicates_message_chain_attachments():
    """A single message-chain image must yield one content item and one attachment."""
    encoded_image = "data:image/jpeg;base64,aGVsbG8="
    query = make_query()
    query.user_message = None
    query.message_chain = platform_message.MessageChain(
        [platform_message.Image(base64=encoded_image)]
    )

    input_data = QueryEntryAdapter._build_input(query)

    content_types = [content.type for content in input_data.contents]
    assert content_types == ["image_base64"]
    assert len(input_data.attachments) == 1
    only_attachment = input_data.attachments[0]
    assert only_attachment.artifact_type == "image"
    assert only_attachment.content == encoded_image
def test_context_builder_preserves_same_source_duplicate_attachments():
    """Two identical images in the user message must both remain as attachments."""
    encoded_image = "data:image/png;base64,aGVsbG8="
    duplicated_elements = [
        provider_message.ContentElement.from_image_base64(encoded_image),
        provider_message.ContentElement.from_image_base64(encoded_image),
    ]
    query = make_query()
    query.user_message = provider_message.Message(role="user", content=duplicated_elements)
    query.message_chain = platform_message.MessageChain([])

    input_data = QueryEntryAdapter._build_input(query)

    attachment_types = [attachment.artifact_type for attachment in input_data.attachments]
    assert attachment_types == ["image", "image"]
@pytest.fixture(autouse=True)
async def clean_agent_state():
"""Reset all singleton stores and create a test database engine."""
@@ -546,6 +577,29 @@ async def test_orchestrator_unregisters_session_after_runner_failure(clean_agent
assert await get_session_registry().get(context["run_id"]) is None
@pytest.mark.asyncio
async def test_orchestrator_unregisters_session_after_event_log_failure(clean_agent_state):
    """An event-log write failure must propagate, skip the runner, and leave no active run."""
    db_engine = clean_agent_state
    descriptor = make_descriptor()
    unused_result = {
        "type": "message.completed",
        "data": {"message": {"role": "assistant", "content": "unused"}},
    }
    plugin_connector = FakePluginConnector(results=[unused_result])
    application = FakeApplication(plugin_connector, db_engine)
    orchestrator = AgentRunOrchestrator(application, FakeRegistry(descriptor))
    # Make the journal step fail so the run aborts before the runner is invoked.
    orchestrator.journal.write_event_log = AsyncMock(side_effect=RuntimeError("journal unavailable"))

    with pytest.raises(RuntimeError, match="journal unavailable"):
        async for _ in orchestrator.run_from_query(make_query()):
            pass

    assert plugin_connector.contexts == []
    assert await get_session_registry().list_active_runs() == []
@pytest.mark.asyncio
async def test_orchestrator_enforces_total_runner_deadline(clean_agent_state):
"""Test that orchestrator enforces total runner timeout."""

View File

@@ -49,6 +49,20 @@ class TestSessionRegistryBasic:
assert 'permissions' not in result
assert '_authorized_ids' not in result
@pytest.mark.asyncio
async def test_register_requires_plugin_identity(self):
    """Registering a run with a blank plugin_identity must raise ValueError."""
    registry = AgentRunSessionRegistry()
    blank_identity_run = {
        'run_id': 'run_missing_identity',
        'runner_id': 'plugin:test/my-runner/default',
        'query_id': 1,
        'plugin_identity': '',
    }

    with pytest.raises(ValueError, match='plugin_identity is required'):
        await registry.register(resources=make_resources(), **blank_identity_run)
@pytest.mark.asyncio
async def test_register_freezes_authorization_snapshot(self):
"""Register should freeze authorization data for the run."""