From 1153433693bcf4fa8868990ee5b97a5c7a5fe970 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <60681390+huanghuoguoguo@users.noreply.github.com> Date: Sat, 13 Jun 2026 21:22:13 +0800 Subject: [PATCH] fix(agent-runner): harden run lifecycle and protocol stores --- .../agent-runner-pluginization/PROTOCOL_V1.md | 19 +++- .../pkg/agent/runner/artifact_store.py | 40 ++++++--- .../pkg/agent/runner/binding_resolver.py | 9 +- .../pkg/agent/runner/event_log_store.py | 23 ++++- src/langbot/pkg/agent/runner/orchestrator.py | 69 +++++++-------- .../pkg/agent/runner/query_entry_adapter.py | 87 +++++++++++++++---- src/langbot/pkg/agent/runner/run_journal.py | 11 ++- .../pkg/agent/runner/session_registry.py | 3 + .../pkg/agent/runner/transcript_store.py | 21 ++++- .../pkg/entity/persistence/transcript.py | 2 +- .../versions/0005_migrate_runner_config.py | 25 +++++- src/langbot/pkg/plugin/handler.py | 62 +++++++------ .../agent/test_event_first_protocol.py | 45 ++++++++++ tests/unit_tests/agent/test_handler_auth.py | 69 +++++++++++++++ .../agent/test_orchestrator_integration.py | 54 ++++++++++++ .../unit_tests/agent/test_session_registry.py | 14 +++ 16 files changed, 450 insertions(+), 103 deletions(-) diff --git a/docs/agent-runner-pluginization/PROTOCOL_V1.md b/docs/agent-runner-pluginization/PROTOCOL_V1.md index f0e55eb5..7e504498 100644 --- a/docs/agent-runner-pluginization/PROTOCOL_V1.md +++ b/docs/agent-runner-pluginization/PROTOCOL_V1.md @@ -324,6 +324,7 @@ class InlineContextPolicy(BaseModel): reason: str | None = None class ContextAPICapabilities(BaseModel): + prompt_get: bool = False history_page: bool = False history_search: bool = False event_get: bool = False @@ -473,8 +474,8 @@ Host 必须校验 `state.updated` 的 scope、key、value 大小和 JSON 可序 ```python # Model -await api.invoke_llm(model_id, messages, funcs=None, extra_args=None) -async for chunk in api.invoke_llm_stream(model_id, messages, funcs=None, extra_args=None): +await api.invoke_llm(llm_model_uuid, messages, funcs=None, extra_args=None) +async for chunk in api.invoke_llm_stream(llm_model_uuid, messages, funcs=None, extra_args=None): ... await api.invoke_rerank(rerank_model_id, query, documents, top_k=None) @@ -486,13 +487,14 @@ await api.call_tool(tool_name, parameters) await api.retrieve_knowledge(kb_id, query_text, top_k=5, filters=None) # History(返回 Transcript projection,不返回原始平台 payload) +await api.get_prompt() await api.history_page(conversation_id=None, before_cursor=None, after_cursor=None, limit=50, direction="backward", include_artifacts=False) await api.history_search(query, filters=None, top_k=10) # Event(返回稳定 event envelope 或受限 raw ref,不默认返回大 payload) await api.event_get(event_id) -await api.event_page(before_cursor=None, limit=50) +await api.event_page(conversation_id=None, event_types=None, before_cursor=None, limit=50) await api.steering_pull(mode="all", limit=None) # Artifact(必须支持大小限制、MIME 校验、过期时间和授权范围) @@ -502,7 +504,7 @@ await api.artifact_read_range(artifact_id, offset=0, length=65536) # State / Storage await api.state_get(scope, key); await api.state_set(scope, key, value); await api.state_delete(scope, key) -await api.state_list(scope, prefix=None) +await api.state_list(scope, prefix=None, limit=100) await api.get_plugin_storage(key); await api.set_plugin_storage(key, value); await api.delete_plugin_storage(key) await api.get_plugin_storage_keys() await api.get_workspace_storage(key); await api.set_workspace_storage(key, value); await api.delete_workspace_storage(key) @@ -513,6 +515,15 @@ await api.get_file(file_key) await api.get_langbot_version() ``` +`invoke_llm()` / `invoke_llm_stream()` 的第一个参数在 SDK 中命名为 +`llm_model_uuid`,wire payload 字段也是 `llm_model_uuid`。该值对 runner +仍是 opaque identifier,不应解析其内部格式。 + +`get_prompt()` 返回当前 query-backed run 的 Host effective prompt messages: +`list[Message]` 的 JSON 形式。该能力只在 `ctx.context.available_apis.prompt_get` +为 true 时可用;没有 query 缓存、prompt 已过期或非 query entry run 时 Host +可以返回错误或空列表。Runner 应在不可用时回退到自己的 config/prompt 策略。 + `steering_pull(mode="all")` 是推荐默认:Host 按 claim 顺序返回全部 pending steering 输入并清空对应队列。`mode="one-at-a-time"` 仅用于 runner 主动节流,每次返回一条。Host 不合并多条用户消息;runner 负责在 turn 边界决定模型侧格式。 Steering 审计使用 EventLog 而不是 Transcript schema 扩展:被 active run 吸收的原始 `message.received` 事件保留原事件类型,并在 `metadata.steering` 标记 `status="queued"`、`trigger_behavior="absorbed_into_active_run"`、`claimed_by_run_id`、`claimed_runner_id`、`claimed_at`。Runner 成功 pull 后,Host 追加 `steering.injected` EventLog 记录,`metadata.steering.status="injected"` 并引用 `source_event_id`。若 run 结束时仍有已 claim 但未 pull 的 steering 输入,Host 追加 `steering.dropped` EventLog 记录,`metadata.steering.status="dropped"` 并引用 `source_event_id`;这不是用户消息事实的删除,只是 dispatch 终态。Transcript 继续只表示会话事实,不承担 dispatch 行为标记。 diff --git a/src/langbot/pkg/agent/runner/artifact_store.py b/src/langbot/pkg/agent/runner/artifact_store.py index 1aebdc76..9afdbce8 100644 --- a/src/langbot/pkg/agent/runner/artifact_store.py +++ b/src/langbot/pkg/agent/runner/artifact_store.py @@ -8,6 +8,7 @@ import uuid import base64 import os +import aiofiles import sqlalchemy from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import sessionmaker @@ -17,6 +18,23 @@ from ...entity.persistence.bstorage import BinaryStorage _FILE_ARTIFACT_METADATA_KEY = '_langbot_file_artifact' _ARTIFACT_THREAD_METADATA_KEY = '_langbot_thread_id' +UTC = datetime.timezone.utc + + +def _utc_now() -> datetime.datetime: + return datetime.datetime.now(UTC) + + +def _as_utc(value: datetime.datetime) -> datetime.datetime: + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value.astimezone(UTC) + + +def _datetime_to_epoch(value: datetime.datetime | None) -> int | None: + if value is None: + return None + return int(_as_utc(value).timestamp()) class ArtifactStore: @@ -191,7 +209,7 @@ class ArtifactStore: runner_id=runner_id, bot_id=bot_id, workspace_id=workspace_id, - created_at=datetime.datetime.utcnow(), + created_at=_utc_now(), expires_at=expires_at, metadata_json=json.dumps(metadata_payload) if metadata_payload else None, ) @@ -336,7 +354,7 @@ class ArtifactStore: } if storage_type == 'file': - return self._read_file_storage(record, artifact_id, offset, limit) + return await self._read_file_storage(record, artifact_id, offset, limit) # For other storage types, return storage reference # (caller can use file_key for chunked transfer) @@ -363,7 +381,7 @@ class ArtifactStore: backing lifecycle remains owned by the storage provider. """ if now is None: - now = datetime.datetime.utcnow() + now = _utc_now() async with self._session_factory() as session: result = await session.execute( @@ -416,7 +434,7 @@ class ArtifactStore: return None return row.value - def _read_file_storage( + async def _read_file_storage( self, record: AgentArtifact, artifact_id: str, @@ -441,9 +459,9 @@ class ArtifactStore: if offset >= file_size: content = b'' else: - with open(real_path, 'rb') as f: - f.seek(offset) - content = f.read(limit) + async with aiofiles.open(real_path, 'rb') as f: + await f.seek(offset) + content = await f.read(limit) return { 'artifact_id': artifact_id, @@ -491,8 +509,8 @@ class ArtifactStore: if row.expires_at is None: return False if now is None: - now = datetime.datetime.utcnow() - return row.expires_at <= now + now = _utc_now() + return _as_utc(row.expires_at) <= _as_utc(now) def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]: """Convert an AgentArtifact row to public dict. @@ -511,7 +529,7 @@ class ArtifactStore: 'conversation_id': row.conversation_id, 'run_id': row.run_id, 'runner_id': row.runner_id, - 'created_at': int(row.created_at.timestamp()) if row.created_at else None, - 'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None, + 'created_at': _datetime_to_epoch(row.created_at), + 'expires_at': _datetime_to_epoch(row.expires_at), 'metadata': self._public_metadata(row.metadata_json), } diff --git a/src/langbot/pkg/agent/runner/binding_resolver.py b/src/langbot/pkg/agent/runner/binding_resolver.py index 6d05ba76..fea5ad6b 100644 --- a/src/langbot/pkg/agent/runner/binding_resolver.py +++ b/src/langbot/pkg/agent/runner/binding_resolver.py @@ -22,7 +22,14 @@ class AgentBindingResolver: event: AgentEventEnvelope, agents: list[AgentConfig], ) -> AgentBinding: - """Resolve exactly one enabled Agent for the event.""" + """Resolve exactly one enabled Agent for the event. + + Callers that source agents from bot/workspace/global configuration must + pre-filter candidates to the event scope before calling this resolver. + The current AgentConfig model represents one already-selected product + Agent and does not carry enough scope metadata to make that decision + safely here. + """ matches = [ agent for agent in agents diff --git a/src/langbot/pkg/agent/runner/event_log_store.py b/src/langbot/pkg/agent/runner/event_log_store.py index d8c5b4e3..eb727714 100644 --- a/src/langbot/pkg/agent/runner/event_log_store.py +++ b/src/langbot/pkg/agent/runner/event_log_store.py @@ -13,6 +13,23 @@ from sqlalchemy.orm import sessionmaker from ...entity.persistence.event_log import EventLog +UTC = datetime.timezone.utc + + +def _utc_now() -> datetime.datetime: + return datetime.datetime.now(UTC) + + +def _datetime_to_epoch(value: datetime.datetime | None) -> int | None: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=UTC) + else: + value = value.astimezone(UTC) + return int(value.timestamp()) + + class EventLogStore: """Store for EventLog records. @@ -107,7 +124,7 @@ class EventLogStore: run_id=run_id, runner_id=runner_id, metadata_json=json.dumps(metadata) if metadata else None, - created_at=datetime.datetime.utcnow(), + created_at=_utc_now(), ) session.add(event) await session.commit() @@ -277,7 +294,7 @@ class EventLogStore: 'id': row.id, 'event_id': row.event_id, 'event_type': row.event_type, - 'event_time': int(row.event_time.timestamp()) if row.event_time else None, + 'event_time': _datetime_to_epoch(row.event_time), 'source': row.source, 'bot_id': row.bot_id, 'workspace_id': row.workspace_id, @@ -293,6 +310,6 @@ class EventLogStore: 'raw_ref': row.raw_ref, 'run_id': row.run_id, 'runner_id': row.runner_id, - 'created_at': int(row.created_at.timestamp()) if row.created_at else None, + 'created_at': _datetime_to_epoch(row.created_at), 'metadata': json.loads(row.metadata_json) if row.metadata_json else {}, } diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index 98d61486..a084dcc2 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -103,40 +103,6 @@ class AgentRunOrchestrator: state_context = build_state_context(event, binding, descriptor) run_id = context['run_id'] - await self._session_registry.register( - run_id=run_id, - runner_id=descriptor.id, - query_id=session_query_id, - plugin_identity=descriptor.get_plugin_id(), - resources=resources, - available_apis=context.get('context', {}).get('available_apis'), - conversation_id=event.conversation_id, - bot_id=event.bot_id, - workspace_id=event.workspace_id, - thread_id=event.thread_id, - state_policy={ - 'enable_state': binding.state_policy.enable_state, - 'state_scopes': list(binding.state_policy.state_scopes), - }, - state_context=state_context, - ) - - event_log_id = await self.journal.write_event_log( - event=event, - binding=binding, - run_id=run_id, - runner_id=descriptor.id, - ) - await self.journal.register_input_artifacts( - event=event, - run_id=run_id, - runner_id=descriptor.id, - ) - if event.event_type == 'message.received' and event.conversation_id: - await self.journal.write_user_transcript( - event=event, - event_log_id=event_log_id, - ) pending_artifact_refs: list[dict[str, typing.Any]] = [] seen_sequences: set[int] = set() @@ -144,6 +110,41 @@ class AgentRunOrchestrator: assistant_transcript_written = False try: + await self._session_registry.register( + run_id=run_id, + runner_id=descriptor.id, + query_id=session_query_id, + plugin_identity=descriptor.get_plugin_id(), + resources=resources, + available_apis=context.get('context', {}).get('available_apis'), + conversation_id=event.conversation_id, + bot_id=event.bot_id, + workspace_id=event.workspace_id, + thread_id=event.thread_id, + state_policy={ + 'enable_state': binding.state_policy.enable_state, + 'state_scopes': list(binding.state_policy.state_scopes), + }, + state_context=state_context, + ) + + event_log_id = await self.journal.write_event_log( + event=event, + binding=binding, + run_id=run_id, + runner_id=descriptor.id, + ) + await self.journal.register_input_artifacts( + event=event, + run_id=run_id, + runner_id=descriptor.id, + ) + if event.event_type == 'message.received' and event.conversation_id: + await self.journal.write_user_transcript( + event=event, + event_log_id=event_log_id, + ) + async for result_dict in self.invoker.invoke(descriptor, context): sequence = result_dict.get('sequence') if sequence is not None: diff --git a/src/langbot/pkg/agent/runner/query_entry_adapter.py b/src/langbot/pkg/agent/runner/query_entry_adapter.py index 23ac81f5..ad15730c 100644 --- a/src/langbot/pkg/agent/runner/query_entry_adapter.py +++ b/src/langbot/pkg/agent/runner/query_entry_adapter.py @@ -27,6 +27,7 @@ from .host_models import ( StatePolicy, DeliveryPolicy, ) +from .config_migration import ConfigMigration from . import events as runner_events @@ -42,6 +43,7 @@ class QueryEntryAdapter: INTERNAL_PREFIX = '_' SENSITIVE_PATTERNS = ('secret', 'token', 'key', 'password', 'credential', 'api_key', 'apikey') PERMISSION_VARS = ('_pipeline_bound_plugins', '_authorized', '_permission') + EVENT_DATA_MAX_STRING_BYTES = 512 @classmethod def query_to_event( @@ -103,8 +105,7 @@ class QueryEntryAdapter: ) -> AgentConfig: """Project the current Pipeline config container into target Agent config.""" pipeline_config = query.pipeline_config or {} - ai_config = pipeline_config.get('ai', {}) - runner_config = ai_config.get('runner_config', {}).get(runner_id, {}) + runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id) agent_id = getattr(query, 'pipeline_uuid', None) # Build resource policy from current config @@ -199,12 +200,13 @@ class QueryEntryAdapter: event_data: dict[str, typing.Any] = {} if message_event and hasattr(message_event, 'model_dump'): try: - event_data = message_event.model_dump(mode='json') + raw_event_data = message_event.model_dump(mode='json') except TypeError: - event_data = message_event.model_dump() + raw_event_data = message_event.model_dump() except Exception: - event_data = {} - event_data.pop('source_platform_object', None) + raw_event_data = {} + if isinstance(raw_event_data, dict): + event_data = cls._compact_event_data(raw_event_data) source_event_type = None if message_event: @@ -231,6 +233,25 @@ class QueryEntryAdapter: data=event_data, ) + @classmethod + def _compact_event_data( + cls, + event_data: dict[str, typing.Any], + ) -> dict[str, typing.Any]: + """Keep only small scalar source-event metadata in event.data.""" + compact: dict[str, typing.Any] = {} + for key, value in event_data.items(): + if key == 'source_platform_object' or key.startswith('_'): + continue + if value is None or isinstance(value, (bool, int, float)): + compact[key] = value + continue + if isinstance(value, str): + if len(value.encode('utf-8')) <= cls.EVENT_DATA_MAX_STRING_BYTES: + compact[key] = value + continue + return compact + @classmethod def _build_scoped_event_id( cls, @@ -430,6 +451,18 @@ class QueryEntryAdapter: import uuid attachments: list[dict[str, typing.Any]] = [] + seen_keys: dict[tuple[str, str, str], set[str]] = {} + + def add_attachment(attachment: dict[str, typing.Any]) -> None: + key = cls._attachment_dedupe_key(attachment) + if key is not None: + source = str(attachment.get('source') or '') + sources = seen_keys.setdefault(key, set()) + if source and sources and source not in sources: + return + if source: + sources.add(source) + attachments.append(attachment) for elem in contents: elem_type = elem.get('type') @@ -437,21 +470,21 @@ class QueryEntryAdapter: if elem_type == 'image_url': image_url = elem.get('image_url') or {} - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'image', 'source': 'url', 'url': image_url.get('url') if isinstance(image_url, dict) else str(image_url), }) elif elem_type == 'image_base64': - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'image', 'source': 'base64', 'content': elem.get('image_base64'), }) elif elem_type == 'file_url': - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'file', 'source': 'url', @@ -459,7 +492,7 @@ class QueryEntryAdapter: 'name': elem.get('file_name'), }) elif elem_type == 'file_base64': - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'file', 'source': 'base64', @@ -478,32 +511,56 @@ class QueryEntryAdapter: artifact_id = str(uuid.uuid4()) # Generate unique ID if isinstance(component, platform_message.Image): - attachments.append({ + image_id = component.image_id or None + image_url = component.url or None + image_base64 = component.base64 or None + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'image', 'source': 'message_chain', - 'id': component.image_id or None, - 'url': component.url or None, + 'id': image_id, + 'url': image_url, + 'content': image_base64, }) elif isinstance(component, platform_message.File): - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'file', 'source': 'message_chain', 'id': component.id or None, 'name': component.name or None, + 'url': component.url or None, + 'content': component.base64 or None, }) elif isinstance(component, platform_message.Voice): - attachments.append({ + add_attachment({ 'artifact_id': artifact_id, 'artifact_type': 'voice', 'source': 'message_chain', 'id': component.voice_id or None, 'url': component.url or None, + 'content': component.base64 or None, }) return attachments + @classmethod + def _attachment_dedupe_key( + cls, + attachment: dict[str, typing.Any], + ) -> tuple[str, str, str] | None: + """Return a stable key for the same attachment across content sources.""" + artifact_type = attachment.get('artifact_type') + if not artifact_type: + return None + for field in ('id', 'url', 'content'): + value = attachment.get(field) + if value: + if field == 'content': + value = hashlib.sha256(str(value).encode('utf-8')).hexdigest() + return str(artifact_type), field, str(value) + return None + @classmethod def _build_delivery_context( cls, diff --git a/src/langbot/pkg/agent/runner/run_journal.py b/src/langbot/pkg/agent/runner/run_journal.py index 812939f9..5a672cdb 100644 --- a/src/langbot/pkg/agent/runner/run_journal.py +++ b/src/langbot/pkg/agent/runner/run_journal.py @@ -165,7 +165,11 @@ class AgentRunJournal: input_json=input_json, run_id=run_id, runner_id=runner_id, - event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None, + event_time=( + datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc) + if event.event_time + else None + ), metadata=metadata, ) @@ -468,7 +472,10 @@ class AgentRunJournal: raw_event_time = event.get('event_time') if raw_event_time: try: - event_time = datetime.datetime.fromtimestamp(raw_event_time) + event_time = datetime.datetime.fromtimestamp( + raw_event_time, + datetime.timezone.utc, + ) except (TypeError, ValueError, OSError): event_time = None diff --git a/src/langbot/pkg/agent/runner/session_registry.py b/src/langbot/pkg/agent/runner/session_registry.py index 6cf37a1b..b03e2fc7 100644 --- a/src/langbot/pkg/agent/runner/session_registry.py +++ b/src/langbot/pkg/agent/runner/session_registry.py @@ -122,6 +122,9 @@ class AgentRunSessionRegistry: state_policy: State policy from binding (enable_state, state_scopes) state_context: Context for state API (scope_keys, binding_identity, etc.) """ + if not isinstance(plugin_identity, str) or not plugin_identity.strip(): + raise ValueError('plugin_identity is required for agent run sessions') + now = int(time.time()) available_apis = copy.deepcopy(available_apis or {}) diff --git a/src/langbot/pkg/agent/runner/transcript_store.py b/src/langbot/pkg/agent/runner/transcript_store.py index e594c9ea..49613ed9 100644 --- a/src/langbot/pkg/agent/runner/transcript_store.py +++ b/src/langbot/pkg/agent/runner/transcript_store.py @@ -14,6 +14,23 @@ from ...entity.persistence.transcript import Transcript from langbot_plugin.api.entities.builtin.provider import message as provider_message +UTC = datetime.timezone.utc + + +def _utc_now() -> datetime.datetime: + return datetime.datetime.now(UTC) + + +def _datetime_to_epoch(value: datetime.datetime | None) -> int | None: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=UTC) + else: + value = value.astimezone(UTC) + return int(value.timestamp()) + + class TranscriptStore: """Store for Transcript records. @@ -94,7 +111,7 @@ class TranscriptStore: seq=0, run_id=run_id, runner_id=runner_id, - created_at=datetime.datetime.utcnow(), + created_at=_utc_now(), metadata_json=json.dumps(metadata) if metadata else None, ) session.add(item) @@ -371,7 +388,7 @@ class TranscriptStore: 'content_json': json.loads(row.content_json) if row.content_json else None, 'seq': row.seq, 'cursor': str(row.seq), - 'created_at': int(row.created_at.timestamp()) if row.created_at else None, + 'created_at': _datetime_to_epoch(row.created_at), 'metadata': json.loads(row.metadata_json) if row.metadata_json else {}, } diff --git a/src/langbot/pkg/entity/persistence/transcript.py b/src/langbot/pkg/entity/persistence/transcript.py index 0c8f4737..3bbdf1e6 100644 --- a/src/langbot/pkg/entity/persistence/transcript.py +++ b/src/langbot/pkg/entity/persistence/transcript.py @@ -55,7 +55,7 @@ class Transcript(Base): """Artifact references as JSON string (list of ArtifactRef).""" # Sequence for cursor-based pagination - seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True) + seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) """Monotonic cursor sequence for pagination.""" # Context diff --git a/src/langbot/pkg/persistence/alembic/versions/0005_migrate_runner_config.py b/src/langbot/pkg/persistence/alembic/versions/0005_migrate_runner_config.py index aa47bdb0..8e7aa42b 100644 --- a/src/langbot/pkg/persistence/alembic/versions/0005_migrate_runner_config.py +++ b/src/langbot/pkg/persistence/alembic/versions/0005_migrate_runner_config.py @@ -29,6 +29,26 @@ def _load_config(config_value): return None +def _update_config(conn, table_name: str, pipeline_uuid: str, migrated_config: dict) -> None: + """Write JSON config using a dialect-compatible bind.""" + config_json = json.dumps(migrated_config) + if conn.dialect.name == 'postgresql': + conn.execute( + sa.text( + f'UPDATE {table_name} ' + 'SET config = CAST(:config AS JSON) ' + 'WHERE uuid = :uuid' + ), + {'config': config_json, 'uuid': pipeline_uuid}, + ) + return + + conn.execute( + sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid'), + {'config': config_json, 'uuid': pipeline_uuid}, + ) + + def upgrade() -> None: """Normalize existing pipeline config containers.""" conn = op.get_bind() @@ -56,10 +76,7 @@ def upgrade() -> None: # Only update if config changed if json.dumps(config, sort_keys=True) != json.dumps(migrated_config, sort_keys=True): - conn.execute( - sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid'), - {'config': json.dumps(migrated_config), 'uuid': pipeline_uuid}, - ) + _update_config(conn, table_name, pipeline_uuid, migrated_config) except Exception: # Skip invalid configs continue diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 53dc7927..41bf3f4f 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -211,19 +211,23 @@ async def _validate_agent_run_session( ) session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return None, handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - ap.logger.warning( - f'{api_name}: caller_plugin_identity {caller_plugin_identity} ' - f'does not match session plugin_identity {session_plugin_identity}' - ) - return None, handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) + if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip(): + ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity') + return None, handler.ActionResponse.error( + message=f'Run session {run_id} has no plugin_identity' + ) + if not caller_plugin_identity: + return None, handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: + ap.logger.warning( + f'{api_name}: caller_plugin_identity {caller_plugin_identity} ' + f'does not match session plugin_identity {session_plugin_identity}' + ) + return None, handler.ActionResponse.error( + message=f'Plugin identity mismatch for run_id {run_id}' + ) if api_capability: available_apis = _get_run_authorization(session).get('available_apis', {}) @@ -384,19 +388,25 @@ async def _validate_run_authorization( ) session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return None, handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}', - ) - if caller_plugin_identity != session_plugin_identity: - ap.logger.warning( - f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} ' - f'does not match session plugin_identity {session_plugin_identity}' - ) - return None, handler.ActionResponse.error( - message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}', - ) + if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip(): + ap.logger.warning( + f'{resource_type.upper()}: run_id {run_id} has no plugin_identity' + ) + return None, handler.ActionResponse.error( + message=f'Run session {run_id} has no plugin_identity', + ) + if not caller_plugin_identity: + return None, handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}', + ) + if caller_plugin_identity != session_plugin_identity: + ap.logger.warning( + f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} ' + f'does not match session plugin_identity {session_plugin_identity}' + ) + return None, handler.ActionResponse.error( + message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}', + ) if not session_registry.is_resource_allowed(session, resource_type, resource_id, operation): ap.logger.warning( diff --git a/tests/unit_tests/agent/test_event_first_protocol.py b/tests/unit_tests/agent/test_event_first_protocol.py index 9ee71f39..3ed93ec4 100644 --- a/tests/unit_tests/agent/test_event_first_protocol.py +++ b/tests/unit_tests/agent/test_event_first_protocol.py @@ -106,6 +106,28 @@ class TestQueryToEventEnvelope: "message_id": "source-message-1", } + def test_query_to_event_keeps_large_payloads_out_of_event_data(self, mock_query): + """Large or nested platform payloads should not be duplicated into event.data.""" + source_event = Mock() + source_event.type = "platform.message.created" + source_event.time = 1700000000 + source_event.sender = None + source_event.model_dump = Mock(return_value={ + "type": "platform.message.created", + "message_id": "source-message-1", + "message_chain": [{"type": "Image", "base64": "data:image/png;base64," + ("a" * 1024)}], + "raw_text": "x" * 1024, + "source_platform_object": {"large": "payload"}, + }) + mock_query.message_event = source_event + + event = QueryEntryAdapter.query_to_event(mock_query) + + assert event.data == { + "type": "platform.message.created", + "message_id": "source-message-1", + } + def test_query_to_event_handles_missing_message_chain(self, mock_query): """Test delivery context building when Query has no message_chain.""" delattr(mock_query, "message_chain") @@ -137,6 +159,29 @@ class TestQueryConfigToAgentConfig: assert agent_config.runner_id == "plugin:author/plugin/runner" + def test_config_to_agent_config_uses_legacy_runner_config_migration(self, mock_query): + """Temporary query adapter must share the normal runner config resolver.""" + mock_query.pipeline_config = { + "ai": { + "runner": {"runner": "local-agent"}, + "local-agent": { + "model": "model-primary", + "knowledge-base": "kb-001", + }, + } + } + + agent_config = QueryEntryAdapter.config_to_agent_config( + mock_query, + "plugin:langbot/local-agent/default", + ) + + assert agent_config.runner_config["model"] == { + "primary": "model-primary", + "fallbacks": [], + } + assert agent_config.runner_config["knowledge-bases"] == ["kb-001"] + def test_resolver_projects_agent_scope(self, mock_query): """Test binding scope projection through the resolver.""" event = QueryEntryAdapter.query_to_event(mock_query) diff --git a/tests/unit_tests/agent/test_handler_auth.py b/tests/unit_tests/agent/test_handler_auth.py index d8cfc482..61bbd537 100644 --- a/tests/unit_tests/agent/test_handler_auth.py +++ b/tests/unit_tests/agent/test_handler_auth.py @@ -2024,6 +2024,75 @@ class TestCallerPluginIdentityValidation: await registry.unregister('run_no_caller_identity') + @pytest.mark.asyncio + async def test_session_missing_plugin_identity_denied(self): + """Malformed legacy sessions without plugin_identity fail closed.""" + from langbot.pkg.agent.runner.session_registry import get_session_registry + + registry = get_session_registry() + resources = make_resources(models=[{'model_id': 'model_001'}]) + session = make_session( + run_id='run_missing_session_identity', + runner_id='plugin:test/runner/default', + plugin_identity='', + resources=resources, + ) + async with registry._lock: + registry._sessions['run_missing_session_identity'] = session + + from langbot.pkg.plugin.handler import _validate_run_authorization + + mock_ap = MagicMock() + mock_ap.logger = MagicMock() + + session, error = await _validate_run_authorization( + 'run_missing_session_identity', + 'model', + 'model_001', + mock_ap, + caller_plugin_identity='test/runner', + ) + + assert session is None + assert error is not None + assert 'no plugin_identity' in error.message + + await registry.unregister('run_missing_session_identity') + + @pytest.mark.asyncio + async def test_pull_api_session_missing_plugin_identity_denied(self): + """Pull API validation also fails closed for missing session identity.""" + from langbot.pkg.agent.runner.session_registry import get_session_registry + + registry = get_session_registry() + session = make_session( + run_id='run_missing_pull_identity', + runner_id='plugin:test/runner/default', + plugin_identity='', + available_apis={'history_page': True}, + ) + async with registry._lock: + registry._sessions['run_missing_pull_identity'] = session + + from langbot.pkg.plugin.handler import _validate_agent_run_session + + mock_ap = MagicMock() + mock_ap.logger = MagicMock() + + session, error = await _validate_agent_run_session( + 'run_missing_pull_identity', + 'test/runner', + mock_ap, + 'HISTORY_PAGE', + 'history_page', + ) + + assert session is None + assert error is not None + assert 'no plugin_identity' in error.message + + await registry.unregister('run_missing_pull_identity') + class TestBackwardCompatStorageNoRunId: """Tests for unscoped storage actions without run_id. diff --git a/tests/unit_tests/agent/test_orchestrator_integration.py b/tests/unit_tests/agent/test_orchestrator_integration.py index 9c803d7b..4cee5ba4 100644 --- a/tests/unit_tests/agent/test_orchestrator_integration.py +++ b/tests/unit_tests/agent/test_orchestrator_integration.py @@ -275,6 +275,37 @@ def test_context_builder_includes_consumable_base64_attachments(): assert input_data.attachments[1].name == "hello.txt" +def test_context_builder_deduplicates_message_chain_attachments(): + query = make_query() + query.user_message = None + query.message_chain = platform_message.MessageChain( + [platform_message.Image(base64="data:image/jpeg;base64,aGVsbG8=")] + ) + + input_data = QueryEntryAdapter._build_input(query) + + assert [content.type for content in input_data.contents] == ["image_base64"] + assert len(input_data.attachments) == 1 + assert input_data.attachments[0].artifact_type == "image" + assert input_data.attachments[0].content == "data:image/jpeg;base64,aGVsbG8=" + + +def test_context_builder_preserves_same_source_duplicate_attachments(): + query = make_query() + query.user_message = provider_message.Message( + role="user", + content=[ + provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), + provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), + ], + ) + query.message_chain = platform_message.MessageChain([]) + + input_data = QueryEntryAdapter._build_input(query) + + assert [attachment.artifact_type for attachment in input_data.attachments] == ["image", "image"] + + @pytest.fixture(autouse=True) async def clean_agent_state(): """Reset all singleton stores and create a test database engine.""" @@ -546,6 +577,29 @@ async def test_orchestrator_unregisters_session_after_runner_failure(clean_agent assert await get_session_registry().get(context["run_id"]) is None +@pytest.mark.asyncio +async def test_orchestrator_unregisters_session_after_event_log_failure(clean_agent_state): + """Journal failures before runner invocation must not leave steerable sessions.""" + db_engine = clean_agent_state + descriptor = make_descriptor() + plugin_connector = FakePluginConnector( + results=[ + { + "type": "message.completed", + "data": {"message": {"role": "assistant", "content": "unused"}}, + } + ] + ) + orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor)) + orchestrator.journal.write_event_log = AsyncMock(side_effect=RuntimeError("journal unavailable")) + + with pytest.raises(RuntimeError, match="journal unavailable"): + [message async for message in orchestrator.run_from_query(make_query())] + + assert plugin_connector.contexts == [] + assert await get_session_registry().list_active_runs() == [] + + @pytest.mark.asyncio async def test_orchestrator_enforces_total_runner_deadline(clean_agent_state): """Test that orchestrator enforces total runner timeout.""" diff --git a/tests/unit_tests/agent/test_session_registry.py b/tests/unit_tests/agent/test_session_registry.py index 4e9704ee..99218484 100644 --- a/tests/unit_tests/agent/test_session_registry.py +++ b/tests/unit_tests/agent/test_session_registry.py @@ -49,6 +49,20 @@ class TestSessionRegistryBasic: assert 'permissions' not in result assert '_authorized_ids' not in result + @pytest.mark.asyncio + async def test_register_requires_plugin_identity(self): + """Agent run sessions must always have an owning plugin identity.""" + registry = AgentRunSessionRegistry() + + with pytest.raises(ValueError, match='plugin_identity is required'): + await registry.register( + run_id='run_missing_identity', + runner_id='plugin:test/my-runner/default', + query_id=1, + plugin_identity='', + resources=make_resources(), + ) + @pytest.mark.asyncio async def test_register_freezes_authorization_snapshot(self): """Register should freeze authorization data for the run."""