From 9aa643b55f4ead2178f08c451df46f2ec8ba75e6 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <60681390+huanghuoguoguo@users.noreply.github.com> Date: Mon, 15 Jun 2026 18:09:05 +0800 Subject: [PATCH] feat(agent-runner): add host run ledger primitives --- src/langbot/pkg/agent/runner/__init__.py | 3 + .../pkg/agent/runner/context_builder.py | 33 +- src/langbot/pkg/agent/runner/orchestrator.py | 105 +- src/langbot/pkg/agent/runner/run_journal.py | 103 +- .../pkg/agent/runner/run_ledger_store.py | 600 ++++++++++++ .../pkg/entity/persistence/agent_run.py | 203 ++++ src/langbot/pkg/persistence/alembic/env.py | 1 + .../8d3a1f2c4b6e_add_agent_run_ledger.py | 203 ++++ src/langbot/pkg/plugin/handler.py | 904 ++++++++++++++++-- .../agent/test_orchestrator_integration.py | 701 ++++++++------ .../agent/test_run_ledger_api_auth.py | 490 ++++++++++ .../unit_tests/agent/test_run_ledger_store.py | 167 ++++ 12 files changed, 3088 insertions(+), 425 deletions(-) create mode 100644 src/langbot/pkg/agent/runner/run_ledger_store.py create mode 100644 src/langbot/pkg/entity/persistence/agent_run.py create mode 100644 src/langbot/pkg/persistence/alembic/versions/8d3a1f2c4b6e_add_agent_run_ledger.py create mode 100644 tests/unit_tests/agent/test_run_ledger_api_auth.py create mode 100644 tests/unit_tests/agent/test_run_ledger_store.py diff --git a/src/langbot/pkg/agent/runner/__init__.py b/src/langbot/pkg/agent/runner/__init__.py index cebed269..0dc533aa 100644 --- a/src/langbot/pkg/agent/runner/__init__.py +++ b/src/langbot/pkg/agent/runner/__init__.py @@ -1,4 +1,5 @@ """Agent runner modules.""" + from __future__ import annotations from .descriptor import AgentRunnerDescriptor @@ -24,6 +25,7 @@ from .session_registry import ( RunAuthorizationSnapshot, get_session_registry, ) +from .run_ledger_store import RunLedgerStore from .events import ( MESSAGE_RECEIVED, MESSAGE_RECALLED, @@ -55,6 +57,7 @@ __all__ = [ 'AgentRunSession', 'RunAuthorizationSnapshot', 'get_session_registry', + 
'RunLedgerStore', 'MESSAGE_RECEIVED', 'MESSAGE_RECALLED', 'GROUP_MEMBER_JOINED', diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py index 665fce64..0b99a52a 100644 --- a/src/langbot/pkg/agent/runner/context_builder.py +++ b/src/langbot/pkg/agent/runner/context_builder.py @@ -420,7 +420,21 @@ class AgentRunContextBuilder: event_page_enabled = 'page' in event_perms and conversation_id is not None artifact_metadata_enabled = 'metadata' in artifact_perms artifact_read_enabled = 'read' in artifact_perms - steering_pull_enabled = bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None + steering_pull_enabled = ( + bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None + ) + run_get_enabled = True + run_list_enabled = conversation_id is not None + run_events_page_enabled = True + run_cancel_enabled = True + run_append_result_enabled = False + run_finalize_enabled = False + run_claim_enabled = False + run_renew_claim_enabled = False + run_release_claim_enabled = False + runtime_register_enabled = False + runtime_heartbeat_enabled = False + runtime_list_enabled = False # Determine state API availability based on binding state_policy. 
state_enabled = False @@ -431,9 +445,8 @@ class AgentRunContextBuilder: state_enabled = True resource_policy = binding.resource_policy - storage_enabled = ( - ('plugin' in storage_perms and resource_policy.allow_plugin_storage) - or ('workspace' in storage_perms and resource_policy.allow_workspace_storage) + storage_enabled = ('plugin' in storage_perms and resource_policy.allow_plugin_storage) or ( + 'workspace' in storage_perms and resource_policy.allow_workspace_storage ) # Get latest cursor and has_history_before if conversation exists @@ -477,5 +490,17 @@ class AgentRunContextBuilder: 'state': state_enabled, 'storage': storage_enabled, 'steering_pull': steering_pull_enabled, + 'run_get': run_get_enabled, + 'run_list': run_list_enabled, + 'run_events_page': run_events_page_enabled, + 'run_cancel': run_cancel_enabled, + 'run_append_result': run_append_result_enabled, + 'run_finalize': run_finalize_enabled, + 'run_claim': run_claim_enabled, + 'run_renew_claim': run_renew_claim_enabled, + 'run_release_claim': run_release_claim_enabled, + 'runtime_register': runtime_register_enabled, + 'runtime_heartbeat': runtime_heartbeat_enabled, + 'runtime_list': runtime_list_enabled, }, } diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index a084dcc2..3205840b 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -103,13 +103,40 @@ class AgentRunOrchestrator: state_context = build_state_context(event, binding, descriptor) run_id = context['run_id'] + available_apis = context.get('context', {}).get('available_apis') + run_authorization = { + 'runner_id': descriptor.id, + 'binding_id': binding.binding_id, + 'plugin_identity': descriptor.get_plugin_id(), + 'resources': resources, + 'available_apis': available_apis, + 'conversation_id': event.conversation_id, + 'bot_id': event.bot_id, + 'workspace_id': event.workspace_id, + 'thread_id': event.thread_id, + 
'state_policy': { + 'enable_state': binding.state_policy.enable_state, + 'state_scopes': list(binding.state_policy.state_scopes), + }, + 'state_context': state_context, + } pending_artifact_refs: list[dict[str, typing.Any]] = [] seen_sequences: set[int] = set() last_sequence = 0 assistant_transcript_written = False + terminal_status: str | None = None + terminal_reason: str | None = None + terminal_usage: dict[str, typing.Any] | None = None try: + await self.journal.create_run( + event=event, + binding=binding, + descriptor=descriptor, + context=context, + authorization=run_authorization, + ) await self._session_registry.register( run_id=run_id, runner_id=descriptor.id, @@ -146,14 +173,15 @@ class AgentRunOrchestrator: ) async for result_dict in self.invoker.invoke(descriptor, context): + result_dict = dict(result_dict) sequence = result_dict.get('sequence') if sequence is not None: try: sequence_int = int(sequence) except (TypeError, ValueError): - self.ap.logger.warning( - f'Runner {descriptor.id} returned invalid result sequence: {sequence}' - ) + self.ap.logger.warning(f'Runner {descriptor.id} returned invalid result sequence: {sequence}') + sequence_int = last_sequence + 1 + result_dict['sequence'] = sequence_int else: if sequence_int in seen_sequences: self.ap.logger.warning( @@ -166,6 +194,8 @@ class AgentRunOrchestrator: f'Runner {descriptor.id} returned non-positive result sequence ' f'{sequence_int} for run {run_id}' ) + sequence_int = last_sequence + 1 + result_dict['sequence'] = sequence_int elif last_sequence and sequence_int != last_sequence + 1: self.ap.logger.warning( f'Runner {descriptor.id} result sequence gap or out-of-order ' @@ -173,6 +203,11 @@ class AgentRunOrchestrator: ) seen_sequences.add(sequence_int) last_sequence = max(last_sequence, sequence_int) + else: + sequence_int = last_sequence + 1 + result_dict['sequence'] = sequence_int + seen_sequences.add(sequence_int) + last_sequence = sequence_int result_type = result_dict.get('type') if 
result_type and not self.result_normalizer.validate_payload( @@ -190,9 +225,21 @@ class AgentRunOrchestrator: runner_id=descriptor.id, ) pending_artifact_refs.append(artifact_ref) + await self.journal.append_run_result( + result_dict=result_dict, + run_id=run_id, + sequence=sequence_int, + artifact_refs=[artifact_ref], + ) await self.result_normalizer.normalize(result_dict, descriptor) continue + await self.journal.append_run_result( + result_dict=result_dict, + run_id=run_id, + sequence=sequence_int, + ) + if result_type == 'state.updated': await self.journal.handle_state_updated_event( result_dict, @@ -204,13 +251,28 @@ class AgentRunOrchestrator: await self.result_normalizer.normalize(result_dict, descriptor) continue - has_completed_message = ( - result_type == 'message.completed' - or ( - result_type == 'run.completed' - and isinstance(result_dict.get('data'), dict) - and bool(result_dict['data'].get('message')) + if result_type == 'run.completed': + terminal_status = 'completed' + terminal_reason = ( + result_dict.get('data', {}).get('finish_reason') + if isinstance(result_dict.get('data'), dict) + else None ) + usage = result_dict.get('usage') + if isinstance(usage, dict): + terminal_usage = usage + elif result_type == 'run.failed': + terminal_status = 'failed' + data = result_dict.get('data') if isinstance(result_dict.get('data'), dict) else {} + terminal_reason = data.get('error') or data.get('code') + usage = result_dict.get('usage') + if isinstance(usage, dict): + terminal_usage = usage + + has_completed_message = result_type == 'message.completed' or ( + result_type == 'run.completed' + and isinstance(result_dict.get('data'), dict) + and bool(result_dict['data'].get('message')) ) if has_completed_message and event.conversation_id and not assistant_transcript_written: merged_refs = self.journal.merge_artifact_refs( @@ -231,6 +293,27 @@ class AgentRunOrchestrator: result = await self.result_normalizer.normalize(result_dict, descriptor) if result is not 
None: yield result + + run_snapshot = await self.journal.get_run(run_id) + if run_snapshot and run_snapshot.get('cancel_requested_at') is not None: + terminal_status = 'cancelled' + terminal_reason = run_snapshot.get('status_reason') or 'cancel_requested' + break + await self.journal.finalize_run( + run_id=run_id, + status=terminal_status or 'completed', + status_reason=terminal_reason, + usage=terminal_usage, + ) + except Exception as exc: + failed_usage = terminal_usage + await self.journal.finalize_run( + run_id=run_id, + status='timeout' if self._is_deadline_exhausted(context) else 'failed', + status_reason=str(exc), + usage=failed_usage, + ) + raise finally: session = await self._session_registry.unregister(run_id) pending_steering = session.get('steering_queue', []) if session else [] @@ -325,9 +408,7 @@ class AgentRunOrchestrator: exc_info=True, ) - self.ap.logger.info( - f'Claimed event {event.event_id} as steering input for run {target_run_id}' - ) + self.ap.logger.info(f'Claimed event {event.event_id} as steering input for run {target_run_id}') return True def _build_steering_item( diff --git a/src/langbot/pkg/agent/runner/run_journal.py b/src/langbot/pkg/agent/runner/run_journal.py index 5a672cdb..985ca643 100644 --- a/src/langbot/pkg/agent/runner/run_journal.py +++ b/src/langbot/pkg/agent/runner/run_journal.py @@ -9,6 +9,7 @@ from .descriptor import AgentRunnerDescriptor from .errors import RunnerProtocolError from .host_models import AgentBinding, AgentEventEnvelope from .persistent_state_store import PersistentStateStore, get_persistent_state_store +from .run_ledger_store import RunLedgerStore # Maximum inline artifact content size (1MB) @@ -21,10 +22,17 @@ class AgentRunJournal: ap: app.Application _persistent_state_store: PersistentStateStore | None + _run_ledger_store: RunLedgerStore | None def __init__(self, ap: app.Application): self.ap = ap self._persistent_state_store = None + self._run_ledger_store = None + + def _get_run_ledger_store(self) -> 
RunLedgerStore: + if self._run_ledger_store is None: + self._run_ledger_store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + return self._run_ledger_store @staticmethod def _to_plain_dict(value: typing.Any) -> dict[str, typing.Any]: @@ -64,6 +72,81 @@ class AgentRunJournal: def _sanitize_attachments(cls, attachments: typing.Iterable[typing.Any]) -> list[dict[str, typing.Any]]: return [cls._sanitize_attachment_ref(attachment) for attachment in attachments] + async def create_run( + self, + *, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + context: dict[str, typing.Any], + authorization: dict[str, typing.Any], + ) -> dict[str, typing.Any]: + """Create the Host-owned run ledger record.""" + runtime = context.get('runtime') if isinstance(context, dict) else {} + return await self._get_run_ledger_store().create_run( + run_id=context['run_id'], + event_id=event.event_id, + binding_id=binding.binding_id, + runner_id=descriptor.id, + conversation_id=event.conversation_id, + thread_id=event.thread_id, + workspace_id=event.workspace_id, + bot_id=event.bot_id, + deadline_at=runtime.get('deadline_at') if isinstance(runtime, dict) else None, + authorization=authorization, + metadata={ + 'event_type': event.event_type, + 'source': event.source, + }, + ) + + async def append_run_result( + self, + *, + result_dict: dict[str, typing.Any], + run_id: str, + sequence: int, + source: str = 'runner', + artifact_refs: list[dict[str, typing.Any]] | None = None, + metadata: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any]: + """Persist one AgentRunResult in the run ledger.""" + usage = result_dict.get('usage') + if hasattr(usage, 'model_dump'): + usage = usage.model_dump(mode='json') + return await self._get_run_ledger_store().append_event( + run_id=run_id, + sequence=sequence, + event_type=str(result_dict.get('type') or 'unknown'), + data=result_dict.get('data') if isinstance(result_dict.get('data'), dict) else 
{}, + usage=usage if isinstance(usage, dict) else None, + source=source, + artifact_refs=artifact_refs, + metadata=metadata, + ) + + async def finalize_run( + self, + *, + run_id: str, + status: str, + status_reason: str | None = None, + usage: dict[str, typing.Any] | None = None, + metadata: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any] | None: + """Finalize or update the Host-owned run ledger record.""" + return await self._get_run_ledger_store().finalize_run( + run_id=run_id, + status=status, + status_reason=status_reason, + usage=usage, + metadata=metadata, + ) + + async def get_run(self, run_id: str) -> dict[str, typing.Any] | None: + """Return the persisted run ledger record.""" + return await self._get_run_ledger_store().get_run(run_id) + async def handle_state_updated_event( self, result_dict: dict[str, typing.Any], @@ -99,9 +182,7 @@ class AgentRunJournal: ) if self._persistent_state_store is None: - self._persistent_state_store = get_persistent_state_store( - self.ap.persistence_mgr.get_db_engine() - ) + self._persistent_state_store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) success, error = await self._persistent_state_store.apply_update_from_event( event=event, @@ -114,13 +195,9 @@ class AgentRunJournal: ) if success: - self.ap.logger.debug( - f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}' - ) + self.ap.logger.debug(f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}') elif error: - self.ap.logger.warning( - f'Runner {descriptor.id} state.updated rejected: {error}' - ) + self.ap.logger.warning(f'Runner {descriptor.id} state.updated rejected: {error}') async def write_event_log( self, @@ -166,9 +243,7 @@ class AgentRunJournal: run_id=run_id, runner_id=runner_id, event_time=( - datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc) - if event.event_time - else None + datetime.datetime.fromtimestamp(event.event_time, 
datetime.timezone.utc) if event.event_time else None ), metadata=metadata, ) @@ -239,9 +314,7 @@ class AgentRunJournal: content=content, ) except Exception as e: - self.ap.logger.warning( - f'Failed to register input artifact {artifact_id}: {e}' - ) + self.ap.logger.warning(f'Failed to register input artifact {artifact_id}: {e}') def decode_attachment_content( self, diff --git a/src/langbot/pkg/agent/runner/run_ledger_store.py b/src/langbot/pkg/agent/runner/run_ledger_store.py new file mode 100644 index 00000000..0c447b37 --- /dev/null +++ b/src/langbot/pkg/agent/runner/run_ledger_store.py @@ -0,0 +1,600 @@ +"""Run ledger store for Host-owned AgentRun and AgentRunEvent records.""" + +from __future__ import annotations + +import datetime +import json +import typing +import uuid + +import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from ...entity.persistence.agent_run import AgentRun, AgentRunEvent, AgentRuntime + + +UTC = datetime.timezone.utc +TERMINAL_STATUSES = {'completed', 'failed', 'cancelled', 'timeout'} + + +def _utc_now() -> datetime.datetime: + return datetime.datetime.now(UTC) + + +def _datetime_to_epoch(value: datetime.datetime | None) -> int | None: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=UTC) + else: + value = value.astimezone(UTC) + return int(value.timestamp()) + + +def _epoch_to_datetime(value: typing.Any) -> datetime.datetime | None: + if value is None: + return None + try: + return datetime.datetime.fromtimestamp(float(value), UTC) + except (TypeError, ValueError, OSError): + return None + + +def _json_dumps(value: typing.Any) -> str | None: + if value is None: + return None + return json.dumps(value) + + +def _json_loads(value: str | None, default: typing.Any) -> typing.Any: + if not value: + return default + try: + return json.loads(value) + except (TypeError, ValueError): + return default + + +class RunLedgerStore: + 
"""Store for Host-owned run lifecycle and result event facts.""" + + engine: AsyncEngine + + def __init__(self, engine: AsyncEngine): + self.engine = engine + self._session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async def create_run( + self, + *, + run_id: str, + event_id: str | None, + binding_id: str | None, + runner_id: str, + conversation_id: str | None = None, + thread_id: str | None = None, + workspace_id: str | None = None, + bot_id: str | None = None, + agent_id: str | None = None, + deadline_at: int | float | None = None, + authorization: dict[str, typing.Any] | None = None, + metadata: dict[str, typing.Any] | None = None, + status: str = 'running', + queue_name: str | None = None, + priority: int = 0, + requested_runtime_id: str | None = None, + ) -> dict[str, typing.Any]: + """Create a run if it does not already exist.""" + now = _utc_now() + async with self._session_factory() as session: + existing = await self._get_run_row(session, run_id) + if existing is not None: + return self._run_to_dict(existing) + + run = AgentRun( + run_id=run_id, + event_id=event_id, + agent_id=agent_id, + binding_id=binding_id, + runner_id=runner_id, + conversation_id=conversation_id, + thread_id=thread_id, + workspace_id=workspace_id, + bot_id=bot_id, + status=status, + queue_name=queue_name, + priority=priority, + requested_runtime_id=requested_runtime_id, + created_at=now, + started_at=now if status == 'running' else None, + updated_at=now, + deadline_at=_epoch_to_datetime(deadline_at), + authorization_json=_json_dumps(authorization), + metadata_json=_json_dumps(metadata), + ) + session.add(run) + await session.commit() + return self._run_to_dict(run) + + async def claim_next_run( + self, + *, + runtime_id: str, + queue_name: str | None = None, + lease_seconds: int = 60, + runner_ids: list[str] | None = None, + ) -> dict[str, typing.Any] | None: + """Claim the next queued or expired-leased run for a runtime.""" + now = _utc_now() + 
lease_expires_at = now + datetime.timedelta(seconds=max(int(lease_seconds), 1)) + async with self._session_factory() as session: + query = sqlalchemy.select(AgentRun).where( + sqlalchemy.or_( + AgentRun.status == 'queued', + sqlalchemy.and_( + AgentRun.status == 'claimed', + AgentRun.claim_lease_expires_at.is_not(None), + AgentRun.claim_lease_expires_at <= now, + ), + ), + sqlalchemy.or_( + AgentRun.requested_runtime_id.is_(None), + AgentRun.requested_runtime_id == runtime_id, + ), + ) + if queue_name is not None: + query = query.where(AgentRun.queue_name == queue_name) + if runner_ids: + query = query.where(AgentRun.runner_id.in_(runner_ids)) + + query = query.order_by(AgentRun.priority.desc(), AgentRun.id.asc()).limit(1).with_for_update( + skip_locked=True + ) + result = await session.execute(query) + run = result.scalars().first() + if run is None: + return None + + run.status = 'claimed' + run.claimed_by_runtime_id = runtime_id + run.claim_token = uuid.uuid4().hex + run.claim_lease_expires_at = lease_expires_at + run.dispatch_attempts = (run.dispatch_attempts or 0) + 1 + run.last_claimed_at = now + run.updated_at = now + await session.commit() + return self._run_to_dict(run) + + async def renew_claim( + self, + *, + run_id: str, + claim_token: str, + lease_seconds: int = 60, + ) -> dict[str, typing.Any] | None: + """Extend a current claim lease if the token still matches.""" + now = _utc_now() + async with self._session_factory() as session: + run = await self._get_run_row(session, run_id) + if run is None or run.status != 'claimed' or run.claim_token != claim_token: + return None + + run.claim_lease_expires_at = now + datetime.timedelta(seconds=max(int(lease_seconds), 1)) + run.updated_at = now + await session.commit() + return self._run_to_dict(run) + + async def release_claim( + self, + *, + run_id: str, + claim_token: str, + status: str = 'queued', + status_reason: str | None = None, + ) -> dict[str, typing.Any] | None: + """Release a current claim lease if 
the token still matches.""" + now = _utc_now() + async with self._session_factory() as session: + run = await self._get_run_row(session, run_id) + if run is None or run.status != 'claimed' or run.claim_token != claim_token: + return None + + run.status = status + run.status_reason = status_reason + run.claimed_by_runtime_id = None + run.claim_token = None + run.claim_lease_expires_at = None + run.updated_at = now + if status in TERMINAL_STATUSES: + run.finished_at = run.finished_at or now + await session.commit() + return self._run_to_dict(run) + + async def append_event( + self, + *, + run_id: str, + sequence: int, + event_type: str, + data: dict[str, typing.Any] | None = None, + usage: dict[str, typing.Any] | None = None, + source: str = 'runner', + artifact_refs: list[dict[str, typing.Any]] | None = None, + metadata: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any]: + """Append one run result event. + + If the same run_id + sequence already exists, the existing row is + returned. This supports retrying append calls idempotently. 
+ """ + async with self._session_factory() as session: + result = await session.execute( + sqlalchemy.select(AgentRunEvent).where( + AgentRunEvent.run_id == run_id, + AgentRunEvent.sequence == sequence, + ) + ) + existing = result.scalars().first() + if existing is not None: + return self._event_to_dict(existing) + + row = AgentRunEvent( + run_id=run_id, + sequence=sequence, + type=event_type, + data_json=_json_dumps(data or {}), + usage_json=_json_dumps(usage), + created_at=_utc_now(), + source=source, + artifact_refs_json=_json_dumps(artifact_refs or []), + metadata_json=_json_dumps(metadata), + ) + session.add(row) + await session.commit() + return self._event_to_dict(row) + + async def finalize_run( + self, + *, + run_id: str, + status: str, + status_reason: str | None = None, + usage: dict[str, typing.Any] | None = None, + cost: dict[str, typing.Any] | None = None, + metadata: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any] | None: + """Update a run to a terminal or current status.""" + now = _utc_now() + async with self._session_factory() as session: + run = await self._get_run_row(session, run_id) + if run is None: + return None + + run.status = status + run.status_reason = status_reason + run.updated_at = now + if status in TERMINAL_STATUSES: + run.finished_at = run.finished_at or now + if usage is not None: + run.usage_json = _json_dumps(usage) + if cost is not None: + run.cost_json = _json_dumps(cost) + if metadata is not None: + existing_metadata = _json_loads(run.metadata_json, {}) + if isinstance(existing_metadata, dict): + existing_metadata.update(metadata) + run.metadata_json = _json_dumps(existing_metadata) + else: + run.metadata_json = _json_dumps(metadata) + await session.commit() + return self._run_to_dict(run) + + async def request_cancel( + self, + *, + run_id: str, + status_reason: str | None = None, + ) -> dict[str, typing.Any] | None: + """Record a cancellation request.""" + now = _utc_now() + async with 
self._session_factory() as session: + run = await self._get_run_row(session, run_id) + if run is None: + return None + run.cancel_requested_at = now + run.updated_at = now + run.status_reason = status_reason or run.status_reason + await session.commit() + return self._run_to_dict(run) + + async def get_run(self, run_id: str) -> dict[str, typing.Any] | None: + """Get one run by run_id.""" + async with self._session_factory() as session: + row = await self._get_run_row(session, run_id) + return self._run_to_dict(row) if row is not None else None + + async def register_runtime( + self, + *, + runtime_id: str, + status: str = 'online', + display_name: str | None = None, + endpoint: str | None = None, + version: str | None = None, + capabilities: dict[str, typing.Any] | None = None, + labels: dict[str, typing.Any] | None = None, + metadata: dict[str, typing.Any] | None = None, + heartbeat_deadline_seconds: int = 60, + ) -> dict[str, typing.Any]: + """Create or update a runtime registry row and record a heartbeat.""" + now = _utc_now() + async with self._session_factory() as session: + runtime = await self._get_runtime_row(session, runtime_id) + if runtime is None: + runtime = AgentRuntime(runtime_id=runtime_id, created_at=now) + session.add(runtime) + + runtime.status = status + runtime.display_name = display_name + runtime.endpoint = endpoint + runtime.version = version + runtime.capabilities_json = _json_dumps(capabilities or {}) + runtime.labels_json = _json_dumps(labels or {}) + runtime.metadata_json = _json_dumps(metadata or {}) + runtime.last_heartbeat_at = now + runtime.heartbeat_deadline_at = now + datetime.timedelta(seconds=max(int(heartbeat_deadline_seconds), 1)) + runtime.updated_at = now + await session.commit() + return self._runtime_to_dict(runtime) + + async def heartbeat_runtime( + self, + *, + runtime_id: str, + status: str = 'online', + heartbeat_deadline_seconds: int = 60, + capabilities: dict[str, typing.Any] | None = None, + labels: dict[str, 
typing.Any] | None = None, + metadata: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any] | None: + """Refresh a runtime heartbeat.""" + now = _utc_now() + async with self._session_factory() as session: + runtime = await self._get_runtime_row(session, runtime_id) + if runtime is None: + return None + + runtime.status = status + runtime.last_heartbeat_at = now + runtime.heartbeat_deadline_at = now + datetime.timedelta(seconds=max(int(heartbeat_deadline_seconds), 1)) + runtime.updated_at = now + if capabilities is not None: + runtime.capabilities_json = _json_dumps(capabilities) + if labels is not None: + runtime.labels_json = _json_dumps(labels) + if metadata is not None: + existing_metadata = _json_loads(runtime.metadata_json, {}) + if isinstance(existing_metadata, dict): + existing_metadata.update(metadata) + runtime.metadata_json = _json_dumps(existing_metadata) + else: + runtime.metadata_json = _json_dumps(metadata) + await session.commit() + return self._runtime_to_dict(runtime) + + async def get_runtime(self, runtime_id: str) -> dict[str, typing.Any] | None: + """Get one runtime by runtime_id.""" + async with self._session_factory() as session: + row = await self._get_runtime_row(session, runtime_id) + return self._runtime_to_dict(row) if row is not None else None + + async def list_runtimes( + self, + *, + statuses: list[str] | None = None, + limit: int = 100, + ) -> list[dict[str, typing.Any]]: + """List runtime registry rows.""" + limit = min(max(int(limit), 1), 500) + async with self._session_factory() as session: + query = sqlalchemy.select(AgentRuntime) + if statuses: + query = query.where(AgentRuntime.status.in_(statuses)) + query = query.order_by(AgentRuntime.id.asc()).limit(limit) + result = await session.execute(query) + return [self._runtime_to_dict(row) for row in result.scalars().all()] + + async def mark_stale_runtimes( + self, + *, + now: datetime.datetime | None = None, + stale_status: str = 'stale', + ) -> list[dict[str, 
typing.Any]]: + """Mark runtimes stale when their heartbeat deadline has passed.""" + current_time = now or _utc_now() + if current_time.tzinfo is None: + current_time = current_time.replace(tzinfo=UTC) + async with self._session_factory() as session: + result = await session.execute( + sqlalchemy.select(AgentRuntime).where( + AgentRuntime.heartbeat_deadline_at.is_not(None), + AgentRuntime.heartbeat_deadline_at < current_time, + AgentRuntime.status != stale_status, + ) + ) + runtimes = result.scalars().all() + for runtime in runtimes: + runtime.status = stale_status + runtime.updated_at = current_time + await session.commit() + return [self._runtime_to_dict(runtime) for runtime in runtimes] + + async def list_runs( + self, + *, + conversation_id: str | None = None, + statuses: list[str] | None = None, + before_id: int | None = None, + limit: int = 50, + bot_id: str | None = None, + workspace_id: str | None = None, + thread_id: str | None = None, + strict_thread: bool = False, + ) -> tuple[list[dict[str, typing.Any]], int | None, bool]: + """Page runs by scope.""" + limit = min(max(int(limit), 1), 100) + async with self._session_factory() as session: + query = sqlalchemy.select(AgentRun) + if conversation_id is not None: + query = query.where(AgentRun.conversation_id == conversation_id) + if statuses: + query = query.where(AgentRun.status.in_(statuses)) + if before_id is not None: + query = query.where(AgentRun.id < before_id) + query = self._apply_scope_filters(query, bot_id, workspace_id, thread_id, strict_thread) + query = query.order_by(AgentRun.id.desc()).limit(limit + 1) + + result = await session.execute(query) + rows = result.scalars().all() + items = [self._run_to_dict(row) for row in rows[:limit]] + has_more = len(rows) > limit + next_cursor = items[-1]['id'] if items and has_more else None + return items, next_cursor, has_more + + async def page_run_events( + self, + *, + run_id: str, + before_sequence: int | None = None, + after_sequence: int | None = 
None, + limit: int = 50, + direction: str = 'forward', + ) -> tuple[list[dict[str, typing.Any]], int | None, int | None, bool]: + """Page result events for one run.""" + limit = min(max(int(limit), 1), 100) + direction = direction if direction in {'forward', 'backward'} else 'forward' + async with self._session_factory() as session: + query = sqlalchemy.select(AgentRunEvent).where(AgentRunEvent.run_id == run_id) + if before_sequence is not None: + query = query.where(AgentRunEvent.sequence < before_sequence) + if after_sequence is not None: + query = query.where(AgentRunEvent.sequence > after_sequence) + + if direction == 'backward': + query = query.order_by(AgentRunEvent.sequence.desc()) + else: + query = query.order_by(AgentRunEvent.sequence.asc()) + query = query.limit(limit + 1) + + result = await session.execute(query) + rows = result.scalars().all() + items = [self._event_to_dict(row) for row in rows[:limit]] + has_more = len(rows) > limit + + if direction == 'backward': + next_cursor = items[-1]['sequence'] if items and has_more else None + prev_cursor = items[0]['sequence'] if items else None + else: + next_cursor = items[-1]['sequence'] if items and has_more else None + prev_cursor = items[0]['sequence'] if items else None + return items, next_cursor, prev_cursor, has_more + + async def _get_run_row( + self, + session: AsyncSession, + run_id: str, + ) -> AgentRun | None: + result = await session.execute(sqlalchemy.select(AgentRun).where(AgentRun.run_id == run_id)) + return result.scalars().first() + + async def _get_runtime_row( + self, + session: AsyncSession, + runtime_id: str, + ) -> AgentRuntime | None: + result = await session.execute(sqlalchemy.select(AgentRuntime).where(AgentRuntime.runtime_id == runtime_id)) + return result.scalars().first() + + def _apply_scope_filters( + self, + query: typing.Any, + bot_id: str | None, + workspace_id: str | None, + thread_id: str | None, + strict_thread: bool, + ) -> typing.Any: + if bot_id is not None: + query = 
query.where(AgentRun.bot_id == bot_id) + if workspace_id is not None: + query = query.where(AgentRun.workspace_id == workspace_id) + if strict_thread: + if thread_id is None: + query = query.where(AgentRun.thread_id.is_(None)) + else: + query = query.where(AgentRun.thread_id == thread_id) + return query + + def _run_to_dict(self, row: AgentRun) -> dict[str, typing.Any]: + return { + 'id': row.id, + 'run_id': row.run_id, + 'event_id': row.event_id, + 'agent_id': row.agent_id, + 'binding_id': row.binding_id, + 'runner_id': row.runner_id, + 'conversation_id': row.conversation_id, + 'thread_id': row.thread_id, + 'workspace_id': row.workspace_id, + 'bot_id': row.bot_id, + 'status': row.status, + 'status_reason': row.status_reason, + 'queue_name': row.queue_name, + 'priority': row.priority, + 'requested_runtime_id': row.requested_runtime_id, + 'claimed_by_runtime_id': row.claimed_by_runtime_id, + 'claim_token': row.claim_token, + 'claim_lease_expires_at': _datetime_to_epoch(row.claim_lease_expires_at), + 'dispatch_attempts': row.dispatch_attempts, + 'last_claimed_at': _datetime_to_epoch(row.last_claimed_at), + 'created_at': _datetime_to_epoch(row.created_at), + 'started_at': _datetime_to_epoch(row.started_at), + 'finished_at': _datetime_to_epoch(row.finished_at), + 'updated_at': _datetime_to_epoch(row.updated_at), + 'deadline_at': _datetime_to_epoch(row.deadline_at), + 'cancel_requested_at': _datetime_to_epoch(row.cancel_requested_at), + 'usage': _json_loads(row.usage_json, None), + 'cost': _json_loads(row.cost_json, None), + 'metadata': _json_loads(row.metadata_json, {}), + } + + def _runtime_to_dict(self, row: AgentRuntime) -> dict[str, typing.Any]: + return { + 'id': row.id, + 'runtime_id': row.runtime_id, + 'status': row.status, + 'display_name': row.display_name, + 'endpoint': row.endpoint, + 'version': row.version, + 'capabilities': _json_loads(row.capabilities_json, {}), + 'labels': _json_loads(row.labels_json, {}), + 'metadata': _json_loads(row.metadata_json, 
class AgentRun(Base):
    """AgentRun stores Host-owned execution lifecycle facts.

    ORM model for the ``agent_run`` table. Every row carries a unique,
    non-null ``run_id`` plus the scope it ran in (bot, workspace, conversation,
    thread), its lifecycle status, queue / claim-lease bookkeeping, timestamps
    and several JSON payloads stored as text.
    """

    __tablename__ = 'agent_run'

    # --- Identity and scope -------------------------------------------------
    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    """Auto-increment ID for pagination."""

    run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
    """Unique AgentRunner run identifier."""

    event_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Input event that triggered this run."""

    agent_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
    """Future Host-owned agent identifier."""

    binding_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Binding that selected this runner."""

    runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
    """Runner descriptor ID."""

    conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Conversation this run belongs to."""

    # thread_id and workspace_id have no single-column index; they are only
    # covered by the composite ix_agent_run_scope_status index below.
    thread_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
    """Thread this run belongs to."""

    workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
    """Workspace this run belongs to."""

    bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Bot UUID this run belongs to."""

    # --- Lifecycle status ---------------------------------------------------
    status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
    """Run lifecycle status."""

    status_reason = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Human-readable terminal or current status reason."""

    # --- Queue and claim lease ----------------------------------------------
    queue_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Host queue name this run is waiting in."""

    # NOTE: default=0 is a Python-side default applied on ORM inserts only.
    priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
    """Higher values are claimed before lower values within a queue."""

    requested_runtime_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Specific runtime requested by the producer, if any."""

    claimed_by_runtime_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Runtime that currently owns the claim lease."""

    claim_token = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
    """Opaque token required to renew or release the current claim."""

    claim_lease_expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
    """When the current claim lease expires."""

    dispatch_attempts = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
    """Number of times this run has been claimed for dispatch."""

    last_claimed_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
    """When this run was last claimed."""

    # --- Timestamps ---------------------------------------------------------
    # datetime.utcnow returns a naive datetime expressed in UTC, so the
    # defaults below are stored without tzinfo.
    created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
    """When the run record was created."""

    started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
    """When execution started."""

    finished_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
    """When execution reached a terminal status."""

    # onupdate re-evaluates the callable whenever the row is updated through
    # SQLAlchemy, so updated_at tracks the last ORM/Core UPDATE.
    updated_at = sqlalchemy.Column(
        sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    )
    """When the run record was last updated."""

    deadline_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
    """Execution deadline if one was assigned."""

    cancel_requested_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
    """When cancellation was requested."""

    # --- JSON payloads (stored as text) -------------------------------------
    usage_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Final or latest aggregate token usage JSON."""

    cost_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Host-calculated cost JSON, if available."""

    authorization_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Run-scoped authorization snapshot JSON."""

    metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Additional metadata JSON."""

    # Composite indexes, in addition to the single-column index=True flags above.
    # NOTE(review): ix_agent_run_queue_claim presumably serves the claim query
    # (filter by queue/status, order by priority then id) - confirm against the store.
    __table_args__ = (
        sqlalchemy.Index(
            'ix_agent_run_scope_status', 'bot_id', 'workspace_id', 'conversation_id', 'thread_id', 'status'
        ),
        sqlalchemy.Index('ix_agent_run_runner_status', 'runner_id', 'status'),
        sqlalchemy.Index('ix_agent_run_queue_claim', 'queue_name', 'status', 'priority', 'id'),
    )
class AgentRunEvent(Base):
    """AgentRunEvent stores one result event emitted by a run.

    ORM model for the ``agent_run_event`` table. Rows are tied to a run through
    the ``run_id`` string (no foreign key is declared), and the
    ``(run_id, sequence)`` pair is unique.
    """

    __tablename__ = 'agent_run_event'

    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    """Auto-increment ID."""

    run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
    """Run that produced this event."""

    sequence = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
    """Monotonic sequence inside the run."""

    type = sqlalchemy.Column(sqlalchemy.String(100), nullable=False, index=True)
    """Result event type."""

    # The *_json columns below hold JSON documents as text.
    data_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Result event payload JSON."""

    usage_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Token usage JSON for this event, if provided."""

    # datetime.utcnow returns a naive datetime expressed in UTC.
    created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
    """When this event was persisted."""

    source = sqlalchemy.Column(sqlalchemy.String(50), nullable=True)
    """Source that appended the event."""

    artifact_refs_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Artifact references associated with this event."""

    metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
    """Additional metadata JSON."""

    # The unique constraint makes a duplicate (run_id, sequence) insert fail.
    # NOTE(review): the extra index covers the same columns as the unique
    # constraint; many backends already index a unique constraint - confirm
    # whether the second index is needed.
    __table_args__ = (
        sqlalchemy.UniqueConstraint('run_id', 'sequence', name='uq_agent_run_event_run_sequence'),
        sqlalchemy.Index('ix_agent_run_event_run_sequence', 'run_id', 'sequence'),
    )
def upgrade() -> None:
    """Create the agent run ledger tables and indexes, skipping what already exists.

    Handles three tables in order: ``agent_run``, ``agent_run_event`` and
    ``agent_runtime``. Each table is created only when it is missing. When
    ``agent_run`` already exists, the queue / claim-lease columns are added if
    absent. Indexes are created through ``_create_index_if_missing``, which
    does nothing when the index is already present or its columns are missing.
    """
    if not _table_exists('agent_run'):
        op.create_table(
            'agent_run',
            sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
            sa.Column('run_id', sa.String(255), nullable=False, unique=True),
            sa.Column('event_id', sa.String(255), nullable=True),
            sa.Column('agent_id', sa.String(255), nullable=True),
            sa.Column('binding_id', sa.String(255), nullable=True),
            sa.Column('runner_id', sa.String(255), nullable=False),
            sa.Column('conversation_id', sa.String(255), nullable=True),
            sa.Column('thread_id', sa.String(255), nullable=True),
            sa.Column('workspace_id', sa.String(255), nullable=True),
            sa.Column('bot_id', sa.String(255), nullable=True),
            sa.Column('status', sa.String(50), nullable=False),
            sa.Column('status_reason', sa.Text(), nullable=True),
            sa.Column('queue_name', sa.String(255), nullable=True),
            sa.Column('priority', sa.Integer(), nullable=False, server_default='0'),
            sa.Column('requested_runtime_id', sa.String(255), nullable=True),
            sa.Column('claimed_by_runtime_id', sa.String(255), nullable=True),
            sa.Column('claim_token', sa.String(255), nullable=True),
            sa.Column('claim_lease_expires_at', sa.DateTime(), nullable=True),
            sa.Column('dispatch_attempts', sa.Integer(), nullable=False, server_default='0'),
            sa.Column('last_claimed_at', sa.DateTime(), nullable=True),
            sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
            sa.Column('started_at', sa.DateTime(), nullable=True),
            sa.Column('finished_at', sa.DateTime(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
            sa.Column('deadline_at', sa.DateTime(), nullable=True),
            sa.Column('cancel_requested_at', sa.DateTime(), nullable=True),
            sa.Column('usage_json', sa.Text(), nullable=True),
            sa.Column('cost_json', sa.Text(), nullable=True),
            sa.Column('authorization_json', sa.Text(), nullable=True),
            sa.Column('metadata_json', sa.Text(), nullable=True),
        )
    else:
        # The table predates this revision: add only the queue / claim-lease
        # columns. The NOT NULL ones carry a server default so existing rows
        # receive a value.
        _add_column_if_missing('agent_run', sa.Column('queue_name', sa.String(255), nullable=True))
        _add_column_if_missing(
            'agent_run', sa.Column('priority', sa.Integer(), nullable=False, server_default='0')
        )
        _add_column_if_missing('agent_run', sa.Column('requested_runtime_id', sa.String(255), nullable=True))
        _add_column_if_missing('agent_run', sa.Column('claimed_by_runtime_id', sa.String(255), nullable=True))
        _add_column_if_missing('agent_run', sa.Column('claim_token', sa.String(255), nullable=True))
        _add_column_if_missing('agent_run', sa.Column('claim_lease_expires_at', sa.DateTime(), nullable=True))
        _add_column_if_missing(
            'agent_run', sa.Column('dispatch_attempts', sa.Integer(), nullable=False, server_default='0')
        )
        _add_column_if_missing('agent_run', sa.Column('last_claimed_at', sa.DateTime(), nullable=True))

    # Indexes run after both branches so a freshly created table and an
    # upgraded one end up with the same set.
    _create_index_if_missing('agent_run', 'ix_agent_run_run_id', ['run_id'], unique=True)
    _create_index_if_missing('agent_run', 'ix_agent_run_event_id', ['event_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_binding_id', ['binding_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_runner_id', ['runner_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_conversation_id', ['conversation_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_bot_id', ['bot_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_status', ['status'])
    _create_index_if_missing('agent_run', 'ix_agent_run_queue_name', ['queue_name'])
    _create_index_if_missing('agent_run', 'ix_agent_run_requested_runtime_id', ['requested_runtime_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_claimed_by_runtime_id', ['claimed_by_runtime_id'])
    _create_index_if_missing('agent_run', 'ix_agent_run_claim_token', ['claim_token'])
    _create_index_if_missing('agent_run', 'ix_agent_run_claim_lease_expires_at', ['claim_lease_expires_at'])
    _create_index_if_missing(
        'agent_run',
        'ix_agent_run_scope_status',
        ['bot_id', 'workspace_id', 'conversation_id', 'thread_id', 'status'],
    )
    _create_index_if_missing('agent_run', 'ix_agent_run_runner_status', ['runner_id', 'status'])
    _create_index_if_missing('agent_run', 'ix_agent_run_queue_claim', ['queue_name', 'status', 'priority', 'id'])

    # Per-run result events; (run_id, sequence) is unique.
    if not _table_exists('agent_run_event'):
        op.create_table(
            'agent_run_event',
            sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
            sa.Column('run_id', sa.String(255), nullable=False),
            sa.Column('sequence', sa.Integer(), nullable=False),
            sa.Column('type', sa.String(100), nullable=False),
            sa.Column('data_json', sa.Text(), nullable=True),
            sa.Column('usage_json', sa.Text(), nullable=True),
            sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
            sa.Column('source', sa.String(50), nullable=True),
            sa.Column('artifact_refs_json', sa.Text(), nullable=True),
            sa.Column('metadata_json', sa.Text(), nullable=True),
            sa.UniqueConstraint('run_id', 'sequence', name='uq_agent_run_event_run_sequence'),
        )

    _create_index_if_missing('agent_run_event', 'ix_agent_run_event_run_id', ['run_id'])
    _create_index_if_missing('agent_run_event', 'ix_agent_run_event_type', ['type'])
    _create_index_if_missing(
        'agent_run_event',
        'ix_agent_run_event_run_sequence',
        ['run_id', 'sequence'],
    )

    # Runtime heartbeat registry.
    if not _table_exists('agent_runtime'):
        op.create_table(
            'agent_runtime',
            sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
            sa.Column('runtime_id', sa.String(255), nullable=False, unique=True),
            sa.Column('status', sa.String(50), nullable=False),
            sa.Column('display_name', sa.String(255), nullable=True),
            sa.Column('endpoint', sa.String(1024), nullable=True),
            sa.Column('version', sa.String(255), nullable=True),
            sa.Column('capabilities_json', sa.Text(), nullable=True),
            sa.Column('labels_json', sa.Text(), nullable=True),
            sa.Column('metadata_json', sa.Text(), nullable=True),
            sa.Column('last_heartbeat_at', sa.DateTime(), nullable=True),
            sa.Column('heartbeat_deadline_at', sa.DateTime(), nullable=True),
            sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
            sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
        )

    _create_index_if_missing('agent_runtime', 'ix_agent_runtime_runtime_id', ['runtime_id'], unique=True)
    _create_index_if_missing('agent_runtime', 'ix_agent_runtime_status', ['status'])
    _create_index_if_missing('agent_runtime', 'ix_agent_runtime_last_heartbeat_at', ['last_heartbeat_at'])
    _create_index_if_missing('agent_runtime', 'ix_agent_runtime_heartbeat_deadline_at', ['heartbeat_deadline_at'])
_table_exists('agent_run'): + op.drop_table('agent_run') diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index a9a5178b..2e8a4700 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -3,6 +3,8 @@ from __future__ import annotations import typing from typing import Any, Union import base64 +import json +import time import traceback import sqlalchemy @@ -30,6 +32,28 @@ from ..agent.runner.config_migration import ConfigMigration from ..agent.runner import config_schema +class _RuntimeActionName: + def __init__(self, value: str): + self.value = value + + +def _plugin_runtime_action(name: str, value: str) -> Any: + return getattr(PluginToRuntimeAction, name, _RuntimeActionName(value)) + + +def _deadline_seconds_from_payload(data: dict[str, Any], default: int = 60) -> int: + deadline_at = data.get('heartbeat_deadline_at') + if deadline_at is not None: + try: + return max(int(float(deadline_at) - time.time()), 1) + except (TypeError, ValueError): + pass + try: + return max(int(data.get('heartbeat_ttl_seconds') or default), 1) + except (TypeError, ValueError): + return default + + def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse: """Create a clean error response for RAG operations. 
def _run_matches_run_scope(session: dict[str, Any], run: dict[str, Any]) -> bool:
    """Return whether ``run`` is visible to the run session ``session``.

    A run is visible when it is the session's own run, or when all of the
    following hold against the session's authorization snapshot:

    * the snapshot has a conversation and the run is in that conversation;
    * the snapshot's ``bot_id`` / ``workspace_id``, when set, equal the run's;
    * the snapshot's ``thread_id`` equals the run's (compared strictly, so a
      snapshot without a thread only matches runs without a thread).

    Args:
        session: Run session dict; ``run_id`` and the authorization snapshot
            are read from it.
        run: Run record dict carrying ``run_id`` and the scope fields.
    """
    authorization = _get_run_authorization(session)

    # The session's own run is always visible. Require a real id so that two
    # missing ids (None == None) can never be mistaken for a match.
    session_run_id = session.get('run_id')
    if session_run_id is not None and run.get('run_id') == session_run_id:
        return True

    conversation_id = authorization.get('conversation_id')
    if not conversation_id or run.get('conversation_id') != conversation_id:
        return False

    # bot / workspace only restrict access when the snapshot pins a value.
    for key in ('bot_id', 'workspace_id'):
        expected = authorization.get(key)
        if expected is not None and expected != run.get(key):
            return False

    return authorization.get('thread_id') == run.get('thread_id')
handler.ActionResponse.error(message='State access is disabled by binding policy') state_scopes = state_policy.get('state_scopes', ['conversation', 'actor']) if scope not in state_scopes: - return None, None, handler.ActionResponse.error( - message=f'Scope "{scope}" is not enabled by binding policy' - ) + return None, None, handler.ActionResponse.error(message=f'Scope "{scope}" is not enabled by binding policy') state_context = authorization['state_context'] scope_key = state_context.get('scope_keys', {}).get(scope) if not scope_key: - return None, None, handler.ActionResponse.error( - message=f'Scope key not available for scope "{scope}"' - ) + return None, None, handler.ActionResponse.error(message=f'Scope key not available for scope "{scope}"') return state_context, scope_key, None @@ -214,44 +255,83 @@ async def _validate_agent_run_session( ap: app.Application, api_name: str, api_capability: str | None = None, + allow_persistent_authorization: bool = False, ) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]: """Validate an AgentRunner pull API run session and run-scoped API access.""" session_registry = get_session_registry() session = await session_registry.get(run_id) if not session: - return None, handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) + if allow_persistent_authorization: + session = await _load_persistent_agent_run_session(run_id, ap, api_name) + if not session: + return None, handler.ActionResponse.error(message=f'Run session {run_id} not found or expired') session_plugin_identity = session.get('plugin_identity') if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip(): ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity') - return None, handler.ActionResponse.error( - message=f'Run session {run_id} has no plugin_identity' - ) + return None, handler.ActionResponse.error(message=f'Run session {run_id} has no plugin_identity') if not 
async def _load_persistent_agent_run_session(
    run_id: str,
    ap: app.Application,
    api_name: str,
) -> dict[str, Any] | None:
    """Rebuild a run session dict from the persisted ``AgentRun`` row.

    Looks up the row whose ``run_id`` matches and decodes its
    ``authorization_json`` snapshot (an empty snapshot when the column is
    empty). ``runner_id`` and ``plugin_identity`` are taken from the snapshot,
    with ``runner_id`` falling back to the row's own value.

    Returns ``None`` when the lookup fails (logged as an error), when no row
    exists, or when the snapshot is not valid JSON or not a JSON object (both
    logged as warnings).
    """
    try:
        from sqlalchemy.ext.asyncio import AsyncSession
        from sqlalchemy.orm import sessionmaker

        from ..entity.persistence.agent_run import AgentRun

        make_session = sessionmaker(
            ap.persistence_mgr.get_db_engine(),
            class_=AsyncSession,
            expire_on_commit=False,
        )
        statement = sqlalchemy.select(AgentRun).where(AgentRun.run_id == run_id)
        async with make_session() as db_session:
            run_row = (await db_session.execute(statement)).scalars().first()
    except Exception as exc:
        ap.logger.error(
            f'{api_name}: failed to load persistent authorization for run_id {run_id}: {exc}', exc_info=True
        )
        return None

    if run_row is None:
        return None

    raw_authorization = run_row.authorization_json
    try:
        authorization = json.loads(raw_authorization) if raw_authorization else {}
    except (TypeError, ValueError) as exc:
        ap.logger.warning(f'{api_name}: run_id {run_id} has invalid authorization_json: {exc}')
        return None

    if not isinstance(authorization, dict):
        ap.logger.warning(f'{api_name}: run_id {run_id} authorization_json is not an object')
        return None

    # Same shape as a live registry session, with empty runtime-only fields.
    restored: dict[str, Any] = {
        'run_id': run_row.run_id,
        'runner_id': authorization.get('runner_id') or run_row.runner_id,
        'query_id': None,
        'plugin_identity': authorization.get('plugin_identity'),
        'authorization': authorization,
        'status': {},
        'steering_queue': [],
    }
    return restored
session.get('plugin_identity') if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip(): - ap.logger.warning( - f'{resource_type.upper()}: run_id {run_id} has no plugin_identity' - ) + ap.logger.warning(f'{resource_type.upper()}: run_id {run_id} has no plugin_identity') return None, handler.ActionResponse.error( message=f'Run session {run_id} has no plugin_identity', ) @@ -1573,12 +1645,14 @@ class RuntimeConnectionHandler(handler.Handler): prompt = getattr(query, 'prompt', None) messages = getattr(prompt, 'messages', []) or [] - return handler.ActionResponse.success(data={ - 'prompt': [ - message.model_dump(mode='json') if hasattr(message, 'model_dump') else message - for message in messages - ], - }) + return handler.ActionResponse.success( + data={ + 'prompt': [ + message.model_dump(mode='json') if hasattr(message, 'model_dump') else message + for message in messages + ], + } + ) @self.action(PluginToRuntimeAction.HISTORY_PAGE) async def history_page(data: dict[str, Any]) -> handler.ActionResponse: @@ -1617,12 +1691,14 @@ class RuntimeConnectionHandler(handler.Handler): return scope_error if not conversation_id: - return handler.ActionResponse.success(data={ - 'items': [], - 'next_cursor': None, - 'prev_cursor': None, - 'has_more': False, - }) + return handler.ActionResponse.success( + data={ + 'items': [], + 'next_cursor': None, + 'prev_cursor': None, + 'has_more': False, + } + ) # Parse cursors before_seq = int(before_cursor) if before_cursor else None @@ -1630,6 +1706,7 @@ class RuntimeConnectionHandler(handler.Handler): # Query transcript from ..agent.runner.transcript_store import TranscriptStore + store = TranscriptStore(self.ap.persistence_mgr.get_db_engine()) try: @@ -1643,12 +1720,14 @@ class RuntimeConnectionHandler(handler.Handler): **_run_scope_filters(session), ) - return handler.ActionResponse.success(data={ - 'items': items, - 'next_cursor': str(next_seq) if next_seq else None, - 'prev_cursor': str(prev_seq) if 
prev_seq else None, - 'has_more': has_more, - }) + return handler.ActionResponse.success( + data={ + 'items': items, + 'next_cursor': str(next_seq) if next_seq else None, + 'prev_cursor': str(prev_seq) if prev_seq else None, + 'has_more': has_more, + } + ) except Exception as e: self.ap.logger.error(f'HISTORY_PAGE error: {e}', exc_info=True) return handler.ActionResponse.error(message=f'History page error: {e}') @@ -1689,14 +1768,17 @@ class RuntimeConnectionHandler(handler.Handler): return scope_error if not conversation_id: - return handler.ActionResponse.success(data={ - 'items': [], - 'total_count': 0, - 'query': query_text, - }) + return handler.ActionResponse.success( + data={ + 'items': [], + 'total_count': 0, + 'query': query_text, + } + ) # Search transcript from ..agent.runner.transcript_store import TranscriptStore + store = TranscriptStore(self.ap.persistence_mgr.get_db_engine()) try: @@ -1709,11 +1791,13 @@ class RuntimeConnectionHandler(handler.Handler): **_run_scope_filters(session), ) - return handler.ActionResponse.success(data={ - 'items': items, - 'total_count': len(items), - 'query': query_text, - }) + return handler.ActionResponse.success( + data={ + 'items': items, + 'total_count': len(items), + 'query': query_text, + } + ) except Exception as e: self.ap.logger.error(f'HISTORY_SEARCH error: {e}', exc_info=True) return handler.ActionResponse.error(message=f'History search error: {e}') @@ -1746,14 +1830,13 @@ class RuntimeConnectionHandler(handler.Handler): # Get event from ..agent.runner.event_log_store import EventLogStore + store = EventLogStore(self.ap.persistence_mgr.get_db_engine()) try: event = await store.get_event(event_id) if not event: - return handler.ActionResponse.error( - message=f'Event {event_id} not found' - ) + return handler.ActionResponse.error(message=f'Event {event_id} not found') # Validate event is in the same conversation as the run, or was created by the same run. 
session_conversation_id = _get_run_authorization(session).get('conversation_id') @@ -1761,9 +1844,7 @@ class RuntimeConnectionHandler(handler.Handler): if event_run_id and event_run_id == run_id: return handler.ActionResponse.success(data=_project_event_record_for_api(event)) if not session_conversation_id or not _event_matches_run_scope(session, event): - return handler.ActionResponse.error( - message=f'Event {event_id} is not accessible by this run' - ) + return handler.ActionResponse.error(message=f'Event {event_id} is not accessible by this run') return handler.ActionResponse.success(data=_project_event_record_for_api(event)) except Exception as e: @@ -1805,18 +1886,21 @@ class RuntimeConnectionHandler(handler.Handler): return scope_error if not conversation_id: - return handler.ActionResponse.success(data={ - 'items': [], - 'next_cursor': None, - 'prev_cursor': None, - 'has_more': False, - }) + return handler.ActionResponse.success( + data={ + 'items': [], + 'next_cursor': None, + 'prev_cursor': None, + 'has_more': False, + } + ) # Parse cursor before_seq = int(before_cursor) if before_cursor else None # Query events from ..agent.runner.event_log_store import EventLogStore + store = EventLogStore(self.ap.persistence_mgr.get_db_engine()) try: @@ -1828,16 +1912,630 @@ class RuntimeConnectionHandler(handler.Handler): **_run_scope_filters(session), ) - return handler.ActionResponse.success(data={ - 'items': [_project_event_record_for_api(item) for item in items], - 'next_cursor': str(next_seq) if next_seq else None, - 'prev_cursor': None, - 'has_more': has_more, - }) + return handler.ActionResponse.success( + data={ + 'items': [_project_event_record_for_api(item) for item in items], + 'next_cursor': str(next_seq) if next_seq else None, + 'prev_cursor': None, + 'has_more': has_more, + } + ) except Exception as e: self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True) return handler.ActionResponse.error(message=f'Event page error: {e}') + 
@self.action(_plugin_runtime_action('RUN_GET', 'run_get')) + async def run_get(data: dict[str, Any]) -> handler.ActionResponse: + """Get one Host-owned run record visible to the current run.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') or run_id + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run get', + api_capability='run_get', + allow_persistent_authorization=True, + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + run = await store.get_run(str(target_run_id)) + if not run: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + if not _run_matches_run_scope(session, run): + return handler.ActionResponse.error(message=f'Run {target_run_id} is not accessible by this run') + return handler.ActionResponse.success(data=run) + except Exception as e: + self.ap.logger.error(f'RUN_GET error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run get error: {e}') + + @self.action(_plugin_runtime_action('RUN_LIST', 'run_list')) + async def run_list(data: dict[str, Any]) -> handler.ActionResponse: + """List Host-owned runs visible to the current run conversation.""" + run_id = data.get('run_id') + conversation_id = data.get('conversation_id') + statuses = data.get('statuses') + before_cursor = data.get('before_cursor') + limit = data.get('limit', 50) + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + + session, error = await _validate_agent_run_session( + run_id, + 
caller_plugin_identity, + self.ap, + 'Run list', + api_capability='run_list', + allow_persistent_authorization=True, + ) + if error: + return error + + conversation_id, scope_error = _resolve_run_conversation( + session, + conversation_id, + 'Run list', + ) + if scope_error: + return scope_error + + if not conversation_id: + return handler.ActionResponse.success( + data={ + 'items': [], + 'next_cursor': None, + 'prev_cursor': None, + 'has_more': False, + } + ) + + if statuses is not None and not isinstance(statuses, list): + return handler.ActionResponse.error(message='statuses must be a list') + try: + before_id = int(before_cursor) if before_cursor else None + except (TypeError, ValueError): + return handler.ActionResponse.error(message='before_cursor must be an integer cursor') + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + items, next_cursor, has_more = await store.list_runs( + conversation_id=conversation_id, + statuses=[str(status) for status in statuses] if statuses else None, + before_id=before_id, + limit=limit, + **_run_scope_filters(session), + ) + return handler.ActionResponse.success( + data={ + 'items': items, + 'next_cursor': str(next_cursor) if next_cursor else None, + 'prev_cursor': None, + 'has_more': has_more, + } + ) + except Exception as e: + self.ap.logger.error(f'RUN_LIST error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run list error: {e}') + + @self.action(_plugin_runtime_action('RUN_EVENTS_PAGE', 'run_events_page')) + async def run_events_page(data: dict[str, Any]) -> handler.ActionResponse: + """Page result events for one Host-owned run visible to current run.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') or run_id + before_cursor = data.get('before_cursor') + after_cursor = data.get('after_cursor') + limit = data.get('limit', 50) + direction = data.get('direction', 'forward') + 
caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run events page', + api_capability='run_events_page', + allow_persistent_authorization=True, + ) + if error: + return error + + try: + before_sequence = int(before_cursor) if before_cursor else None + after_sequence = int(after_cursor) if after_cursor else None + except (TypeError, ValueError): + return handler.ActionResponse.error(message='run event cursors must be integer sequences') + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + run = await store.get_run(str(target_run_id)) + if not run: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + if not _run_matches_run_scope(session, run): + return handler.ActionResponse.error(message=f'Run {target_run_id} is not accessible by this run') + + items, next_cursor, prev_cursor, has_more = await store.page_run_events( + run_id=str(target_run_id), + before_sequence=before_sequence, + after_sequence=after_sequence, + limit=limit, + direction=str(direction or 'forward'), + ) + return handler.ActionResponse.success( + data={ + 'items': items, + 'next_cursor': str(next_cursor) if next_cursor else None, + 'prev_cursor': str(prev_cursor) if prev_cursor else None, + 'has_more': has_more, + } + ) + except Exception as e: + self.ap.logger.error(f'RUN_EVENTS_PAGE error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run events page error: {e}') + + @self.action(_plugin_runtime_action('RUN_CANCEL', 'run_cancel')) + async def run_cancel(data: dict[str, Any]) -> handler.ActionResponse: + """Request cancellation for one Host-owned run visible to the 
current run.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') or run_id + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run cancel', + api_capability='run_cancel', + allow_persistent_authorization=True, + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + run = await store.get_run(str(target_run_id)) + if not run: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + if not _run_matches_run_scope(session, run): + return handler.ActionResponse.error(message=f'Run {target_run_id} is not accessible by this run') + + updated = await store.request_cancel( + run_id=str(target_run_id), + status_reason=data.get('status_reason') or data.get('reason'), + ) + if not updated: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + return handler.ActionResponse.success(data=updated) + except Exception as e: + self.ap.logger.error(f'RUN_CANCEL error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run cancel error: {e}') + + @self.action(_plugin_runtime_action('RUN_APPEND_RESULT', 'run_append_result')) + async def run_append_result(data: dict[str, Any]) -> handler.ActionResponse: + """Append one result event for a Host-owned run visible to the current run.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') or run_id + caller_plugin_identity = data.get('caller_plugin_identity') + result = data.get('result') if isinstance(data.get('result'), dict) else {} + + if not run_id: + return handler.ActionResponse.error(message='run_id is 
required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + + try: + sequence = int(data.get('sequence') or result.get('sequence')) + except (TypeError, ValueError): + return handler.ActionResponse.error(message='sequence is required and must be an integer') + + event_type = data.get('event_type') or data.get('type') or result.get('type') + if not event_type: + return handler.ActionResponse.error(message='event_type is required') + + event_data = data.get('data') if isinstance(data.get('data'), dict) else result.get('data') + usage = data.get('usage') if isinstance(data.get('usage'), dict) else result.get('usage') + artifact_refs = data.get('artifact_refs') if isinstance(data.get('artifact_refs'), list) else None + metadata = data.get('metadata') if isinstance(data.get('metadata'), dict) else None + + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run append result', + api_capability='run_append_result', + allow_persistent_authorization=True, + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + run = await store.get_run(str(target_run_id)) + if not run: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + if not _run_matches_run_scope(session, run): + return handler.ActionResponse.error(message=f'Run {target_run_id} is not accessible by this run') + + event = await store.append_event( + run_id=str(target_run_id), + sequence=sequence, + event_type=str(event_type), + data=event_data if isinstance(event_data, dict) else {}, + usage=usage if isinstance(usage, dict) else None, + source=str(data.get('source') or result.get('source') or 'runner'), + artifact_refs=artifact_refs, + metadata=metadata, + ) + return handler.ActionResponse.success(data=event) + except Exception as e: + self.ap.logger.error(f'RUN_APPEND_RESULT 
error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run append result error: {e}') + + @self.action(_plugin_runtime_action('RUN_FINALIZE', 'run_finalize')) + async def run_finalize(data: dict[str, Any]) -> handler.ActionResponse: + """Finalize one Host-owned run visible to the current run.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') or run_id + caller_plugin_identity = data.get('caller_plugin_identity') + status = data.get('status') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + if not status: + return handler.ActionResponse.error(message='status is required') + + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run finalize', + api_capability='run_finalize', + allow_persistent_authorization=True, + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + run = await store.get_run(str(target_run_id)) + if not run: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + if not _run_matches_run_scope(session, run): + return handler.ActionResponse.error(message=f'Run {target_run_id} is not accessible by this run') + + updated = await store.finalize_run( + run_id=str(target_run_id), + status=str(status), + status_reason=data.get('status_reason') or data.get('reason'), + usage=data.get('usage') if isinstance(data.get('usage'), dict) else None, + cost=data.get('cost') if isinstance(data.get('cost'), dict) else None, + metadata=data.get('metadata') if isinstance(data.get('metadata'), dict) else None, + ) + if not updated: + return handler.ActionResponse.error(message=f'Run {target_run_id} not found') + return handler.ActionResponse.success(data=updated) + except Exception as e: + 
self.ap.logger.error(f'RUN_FINALIZE error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run finalize error: {e}') + + @self.action(_plugin_runtime_action('RUNTIME_REGISTER', 'runtime_register')) + async def runtime_register(data: dict[str, Any]) -> handler.ActionResponse: + """Register or update one Host-owned runtime registry record.""" + run_id = data.get('run_id') + runtime_id = data.get('runtime_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not runtime_id: + return handler.ActionResponse.error(message='runtime_id is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Runtime register', + api_capability='runtime_register', + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + runtime = await store.register_runtime( + runtime_id=str(runtime_id), + status=str(data.get('status') or 'online'), + display_name=data.get('display_name'), + endpoint=data.get('endpoint'), + version=data.get('version'), + capabilities=data.get('capabilities') if isinstance(data.get('capabilities'), dict) else {}, + labels=data.get('labels') if isinstance(data.get('labels'), dict) else {}, + metadata=data.get('metadata') if isinstance(data.get('metadata'), dict) else {}, + heartbeat_deadline_seconds=_deadline_seconds_from_payload(data), + ) + return handler.ActionResponse.success(data=runtime) + except Exception as e: + self.ap.logger.error(f'RUNTIME_REGISTER error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Runtime register error: {e}') + + @self.action(_plugin_runtime_action('RUNTIME_HEARTBEAT', 'runtime_heartbeat')) + async def runtime_heartbeat(data: dict[str, Any]) -> handler.ActionResponse: + """Refresh one Host-owned runtime heartbeat.""" + 
run_id = data.get('run_id') + runtime_id = data.get('runtime_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not runtime_id: + return handler.ActionResponse.error(message='runtime_id is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Runtime heartbeat', + api_capability='runtime_heartbeat', + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + runtime = await store.heartbeat_runtime( + runtime_id=str(runtime_id), + status=str(data.get('status') or 'online'), + capabilities=data.get('capabilities') if isinstance(data.get('capabilities'), dict) else None, + labels=data.get('labels') if isinstance(data.get('labels'), dict) else None, + metadata=data.get('metadata') if isinstance(data.get('metadata'), dict) else None, + heartbeat_deadline_seconds=_deadline_seconds_from_payload(data), + ) + if runtime is None: + return handler.ActionResponse.error(message=f'Runtime {runtime_id} not found') + return handler.ActionResponse.success(data=runtime) + except Exception as e: + self.ap.logger.error(f'RUNTIME_HEARTBEAT error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Runtime heartbeat error: {e}') + + @self.action(_plugin_runtime_action('RUNTIME_LIST', 'runtime_list')) + async def runtime_list(data: dict[str, Any]) -> handler.ActionResponse: + """List Host-owned runtime registry records.""" + run_id = data.get('run_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Runtime list', + api_capability='runtime_list', + ) + if error: + return error + + 
statuses = data.get('statuses') + if statuses is not None and not isinstance(statuses, list): + return handler.ActionResponse.error(message='statuses must be a list') + labels = data.get('labels') if isinstance(data.get('labels'), dict) else {} + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + runtimes = await store.list_runtimes( + statuses=[str(status) for status in statuses] if statuses else None, + limit=data.get('limit', 50), + ) + if labels: + runtimes = [ + runtime + for runtime in runtimes + if all(runtime.get('labels', {}).get(key) == value for key, value in labels.items()) + ] + return handler.ActionResponse.success( + data={ + 'items': runtimes, + 'next_cursor': None, + 'prev_cursor': None, + 'has_more': False, + } + ) + except Exception as e: + self.ap.logger.error(f'RUNTIME_LIST error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Runtime list error: {e}') + + @self.action(_plugin_runtime_action('RUN_CLAIM', 'run_claim')) + async def run_claim(data: dict[str, Any]) -> handler.ActionResponse: + """Claim one queued run for a runtime lease.""" + run_id = data.get('run_id') + runtime_id = data.get('runtime_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not runtime_id: + return handler.ActionResponse.error(message='runtime_id is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run claim', + api_capability='run_claim', + ) + if error: + return error + + runner_ids = data.get('runner_ids') + if runner_ids is not None and not isinstance(runner_ids, list): + return handler.ActionResponse.error(message='runner_ids must be a list') + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + 
run = await store.claim_next_run( + runtime_id=str(runtime_id), + queue_name=data.get('queue_name'), + lease_seconds=data.get('lease_seconds', 60), + runner_ids=[str(item) for item in runner_ids] if runner_ids else None, + ) + if run is None: + return handler.ActionResponse.error(message='No queued run available') + return handler.ActionResponse.success(data=run) + except Exception as e: + self.ap.logger.error(f'RUN_CLAIM error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run claim error: {e}') + + @self.action(_plugin_runtime_action('RUN_RENEW_CLAIM', 'run_renew_claim')) + async def run_renew_claim(data: dict[str, Any]) -> handler.ActionResponse: + """Renew one run claim lease.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') + runtime_id = data.get('runtime_id') + claim_token = data.get('claim_token') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + if not runtime_id: + return handler.ActionResponse.error(message='runtime_id is required') + if not claim_token: + return handler.ActionResponse.error(message='claim_token is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run renew claim', + api_capability='run_renew_claim', + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + current = await store.get_run(str(target_run_id)) + if not current or current.get('claimed_by_runtime_id') != runtime_id: + return handler.ActionResponse.error(message=f'Run claim {target_run_id} not found') + run = await store.renew_claim( + run_id=str(target_run_id), + claim_token=str(claim_token), + lease_seconds=data.get('lease_seconds', 60), + ) + if run 
is None: + return handler.ActionResponse.error(message=f'Run claim {target_run_id} not found') + return handler.ActionResponse.success(data=run) + except Exception as e: + self.ap.logger.error(f'RUN_RENEW_CLAIM error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run renew claim error: {e}') + + @self.action(_plugin_runtime_action('RUN_RELEASE_CLAIM', 'run_release_claim')) + async def run_release_claim(data: dict[str, Any]) -> handler.ActionResponse: + """Release one run claim lease.""" + run_id = data.get('run_id') + target_run_id = data.get('target_run_id') + runtime_id = data.get('runtime_id') + claim_token = data.get('claim_token') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + if not target_run_id: + return handler.ActionResponse.error(message='target_run_id is required') + if not runtime_id: + return handler.ActionResponse.error(message='runtime_id is required') + if not claim_token: + return handler.ActionResponse.error(message='claim_token is required') + + _session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Run release claim', + api_capability='run_release_claim', + ) + if error: + return error + + from ..agent.runner.run_ledger_store import RunLedgerStore + + store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine()) + + try: + current = await store.get_run(str(target_run_id)) + if not current or current.get('claimed_by_runtime_id') != runtime_id: + return handler.ActionResponse.error(message=f'Run claim {target_run_id} not found') + run = await store.release_claim( + run_id=str(target_run_id), + claim_token=str(claim_token), + status=str(data.get('status') or 'queued'), + status_reason=data.get('status_reason') or data.get('reason'), + ) + if run is None: + return handler.ActionResponse.error(message=f'Run claim {target_run_id} not found') + return 
handler.ActionResponse.success(data=run) + except Exception as e: + self.ap.logger.error(f'RUN_RELEASE_CLAIM error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Run release claim error: {e}') + @self.action(PluginToRuntimeAction.STEERING_PULL) async def steering_pull(data: dict[str, Any]) -> handler.ActionResponse: """Pull pending steering/follow-up inputs for the current run.""" @@ -1892,21 +2590,25 @@ class RuntimeConnectionHandler(handler.Handler): source='agent_runner', bot_id=conversation.get('bot_id') if isinstance(conversation, dict) else None, workspace_id=conversation.get('workspace_id') if isinstance(conversation, dict) else None, - conversation_id=conversation.get('conversation_id') if isinstance(conversation, dict) else None, + conversation_id=conversation.get('conversation_id') + if isinstance(conversation, dict) + else None, thread_id=conversation.get('thread_id') if isinstance(conversation, dict) else None, actor_type=actor.get('actor_type') if isinstance(actor, dict) else None, actor_id=actor.get('actor_id') if isinstance(actor, dict) else None, actor_name=actor.get('actor_name') if isinstance(actor, dict) else None, subject_type=subject.get('subject_type') if isinstance(subject, dict) else None, subject_id=subject.get('subject_id') if isinstance(subject, dict) else None, - input_summary=f"steering injected from {event.get('event_id')}", + input_summary=f'steering injected from {event.get("event_id")}', run_id=run_id, runner_id=session.get('runner_id') if isinstance(session, dict) else None, metadata={ 'steering': { 'status': 'injected', 'source_event_id': event.get('event_id'), - 'claimed_by_run_id': item.get('claimed_run_id') if isinstance(item, dict) else run_id, + 'claimed_by_run_id': item.get('claimed_run_id') + if isinstance(item, dict) + else run_id, 'claimed_runner_id': item.get('runner_id') if isinstance(item, dict) else None, 'claimed_at': item.get('claimed_at') if isinstance(item, dict) else None, 'pull_mode': 
str(mode or 'all'), @@ -1951,14 +2653,13 @@ class RuntimeConnectionHandler(handler.Handler): # Get artifact metadata from ..agent.runner.artifact_store import ArtifactStore + store = ArtifactStore(self.ap.persistence_mgr.get_db_engine()) try: metadata = await store.get_authorization_metadata(artifact_id) if not metadata: - return handler.ActionResponse.error( - message=f'Artifact {artifact_id} not found' - ) + return handler.ActionResponse.error(message=f'Artifact {artifact_id} not found') # Validate artifact access scope is_allowed, error_msg = _validate_artifact_access(session, metadata, 'metadata') @@ -2021,14 +2722,13 @@ class RuntimeConnectionHandler(handler.Handler): # Get artifact metadata first to validate access from ..agent.runner.artifact_store import ArtifactStore + store = ArtifactStore(self.ap.persistence_mgr.get_db_engine()) try: metadata = await store.get_authorization_metadata(artifact_id) if not metadata: - return handler.ActionResponse.error( - message=f'Artifact {artifact_id} not found' - ) + return handler.ActionResponse.error(message=f'Artifact {artifact_id} not found') # Validate artifact access scope is_allowed, error_msg = _validate_artifact_access(session, metadata, 'read') @@ -2043,9 +2743,7 @@ class RuntimeConnectionHandler(handler.Handler): ) if not result: - return handler.ActionResponse.error( - message=f'Failed to read artifact {artifact_id}' - ) + return handler.ActionResponse.error(message=f'Failed to read artifact {artifact_id}') return handler.ActionResponse.success(data=result) except ValueError as e: @@ -2093,6 +2791,7 @@ class RuntimeConnectionHandler(handler.Handler): # Get state from persistent store from ..agent.runner.persistent_state_store import get_persistent_state_store + store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) try: @@ -2144,6 +2843,7 @@ class RuntimeConnectionHandler(handler.Handler): # Set state in persistent store from ..agent.runner.persistent_state_store import 
get_persistent_state_store + store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) try: @@ -2202,6 +2902,7 @@ class RuntimeConnectionHandler(handler.Handler): # Delete state from persistent store from ..agent.runner.persistent_state_store import get_persistent_state_store + store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) try: @@ -2250,14 +2951,17 @@ class RuntimeConnectionHandler(handler.Handler): # List state keys from persistent store from ..agent.runner.persistent_state_store import get_persistent_state_store + store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) try: keys, has_more = await store.state_list(scope_key, prefix, limit) - return handler.ActionResponse.success(data={ - 'keys': keys, - 'has_more': has_more, - }) + return handler.ActionResponse.success( + data={ + 'keys': keys, + 'has_more': has_more, + } + ) except Exception as e: self.ap.logger.error(f'STATE_LIST error: {e}', exc_info=True) return handler.ActionResponse.error(message=f'State list error: {e}') diff --git a/tests/unit_tests/agent/test_orchestrator_integration.py b/tests/unit_tests/agent/test_orchestrator_integration.py index 4cee5ba4..05fb2ecc 100644 --- a/tests/unit_tests/agent/test_orchestrator_integration.py +++ b/tests/unit_tests/agent/test_orchestrator_integration.py @@ -1,4 +1,5 @@ """Integration-style tests for AgentRunOrchestrator with a fake plugin runner.""" + from __future__ import annotations import asyncio @@ -15,6 +16,7 @@ from langbot.pkg.agent.runner.orchestrator import AgentRunOrchestrator from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter from langbot.pkg.agent.runner.binding_resolver import AgentBindingResolver from langbot.pkg.agent.runner.session_registry import get_session_registry +from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore from langbot.pkg.agent.runner.persistent_state_store import reset_persistent_state_store from 
langbot_plugin.api.entities.builtin.platform import entities as platform_entities from langbot_plugin.api.entities.builtin.platform import events as platform_events @@ -24,7 +26,7 @@ from langbot_plugin.api.entities.builtin.provider import session as provider_ses from langbot_plugin.api.entities.builtin.resource import tool as resource_tool -RUNNER_ID = "plugin:langbot/local-agent/default" +RUNNER_ID = 'plugin:langbot/local-agent/default' class FakeLogger: @@ -46,22 +48,22 @@ class FakeLogger: class FakeVersionManager: def get_current_version(self): - return "test-version" + return 'test-version' class FakeModel: - def __init__(self, model_type: str = "chat"): + def __init__(self, model_type: str = 'chat'): self.model_entity = types.SimpleNamespace(model_type=model_type) - self.provider_entity = types.SimpleNamespace(name="fake-provider") + self.provider_entity = types.SimpleNamespace(name='fake-provider') class FakeKnowledgeBase: def __init__(self, kb_id: str): self.kb_id = kb_id - self.knowledge_base_entity = types.SimpleNamespace(kb_type="fake") + self.knowledge_base_entity = types.SimpleNamespace(kb_type='fake') def get_name(self): - return f"KB {self.kb_id}" + return f'KB {self.kb_id}' class FakePluginConnector: @@ -78,13 +80,13 @@ class FakePluginConnector: async def run_agent(self, plugin_author, plugin_name, runner_name, context): self.calls.append( { - "plugin_author": plugin_author, - "plugin_name": plugin_name, - "runner_name": runner_name, + 'plugin_author': plugin_author, + 'plugin_name': plugin_name, + 'runner_name': runner_name, } ) self.contexts.append(context) - self.sessions_during_run.append(await get_session_registry().get(context["run_id"])) + self.sessions_during_run.append(await get_session_registry().get(context['run_id'])) if self.error: raise self.error @@ -101,7 +103,7 @@ class FakeRegistry: self.calls: list[dict] = [] async def get(self, runner_id, bound_plugins=None): - self.calls.append({"runner_id": runner_id, "bound_plugins": 
bound_plugins}) + self.calls.append({'runner_id': runner_id, 'bound_plugins': bound_plugins}) assert runner_id == self.descriptor.id return self.descriptor @@ -121,59 +123,57 @@ class FakeApplication: self.plugin_connector = plugin_connector self.persistence_mgr = FakePersistenceManager(db_engine) - self.model_mgr = types.SimpleNamespace( - get_model_by_uuid=AsyncMock(return_value=FakeModel()) - ) + self.model_mgr = types.SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=FakeModel())) self.rag_mgr = types.SimpleNamespace( - get_knowledge_base_by_uuid=AsyncMock(return_value=FakeKnowledgeBase("kb_001")) + get_knowledge_base_by_uuid=AsyncMock(return_value=FakeKnowledgeBase('kb_001')) ) self.skill_mgr = types.SimpleNamespace( skills={ - "demo": { - "name": "demo", - "display_name": "Demo Skill", - "description": "Helps with demo tasks.", + 'demo': { + 'name': 'demo', + 'display_name': 'Demo Skill', + 'description': 'Helps with demo tasks.', }, - "hidden": { - "name": "hidden", - "display_name": "Hidden Skill", - "description": "Not bound to this pipeline.", + 'hidden': { + 'name': 'hidden', + 'display_name': 'Hidden Skill', + 'description': 'Not bound to this pipeline.', }, } ) class FakeConversation: - uuid = "conv_existing" + uuid = 'conv_existing' create_time = datetime.datetime(2026, 5, 15, 12, 0, 0) def make_descriptor() -> AgentRunnerDescriptor: return AgentRunnerDescriptor( id=RUNNER_ID, - source="plugin", - label={"en_US": "Local Agent"}, - plugin_author="langbot", - plugin_name="local-agent", - runner_name="default", + source='plugin', + label={'en_US': 'Local Agent'}, + plugin_author='langbot', + plugin_name='local-agent', + runner_name='default', capabilities={ - "streaming": True, - "tool_calling": True, - "knowledge_retrieval": True, - "skill_authoring": True, + 'streaming': True, + 'tool_calling': True, + 'knowledge_retrieval': True, + 'skill_authoring': True, }, permissions={ - "models": ["invoke", "stream"], - "tools": ["detail", "call"], - 
"knowledge_bases": ["list", "retrieve"], - "history": ["page", "search"], - "events": ["get", "page"], - "artifacts": ["metadata", "read"], - "storage": ["plugin"], + 'models': ['invoke', 'stream'], + 'tools': ['detail', 'call'], + 'knowledge_bases': ['list', 'retrieve'], + 'history': ['page', 'search'], + 'events': ['get', 'page'], + 'artifacts': ['metadata', 'read'], + 'storage': ['plugin'], }, config_schema=[ - {"name": "model", "type": "model-fallback-selector"}, - {"name": "knowledge-bases", "type": "knowledge-base-multi-selector", "default": []}, + {'name': 'model', 'type': 'model-fallback-selector'}, + {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []}, ], ) @@ -185,39 +185,39 @@ def make_query(): message_chain = platform_message.MessageChain( [ platform_message.Source( - id="msg_001", + id='msg_001', time=datetime.datetime(2026, 5, 15, 12, 0, 0), ), - platform_message.Plain(text="hello"), - platform_message.File(name="spec.txt", url="https://example.com/spec.txt"), + platform_message.Plain(text='hello'), + platform_message.File(name='spec.txt', url='https://example.com/spec.txt'), ] ) - sender = platform_entities.Friend(id="user_001", nickname="Alice", remark=None) + sender = platform_entities.Friend(id='user_001', nickname='Alice', remark=None) message_event = platform_events.FriendMessage(sender=sender, message_chain=message_chain, time=1_784_098_800.0) session = types.SimpleNamespace( launcher_type=provider_session.LauncherTypes.PERSON, - launcher_id="user_001", - sender_id="user_001", + launcher_id='user_001', + sender_id='user_001', using_conversation=FakeConversation(), ) return types.SimpleNamespace( query_id=1001, launcher_type=provider_session.LauncherTypes.PERSON, - launcher_id="user_001", - sender_id="user_001", + launcher_id='user_001', + sender_id='user_001', message_event=message_event, message_chain=message_chain, - bot_uuid="bot_001", - pipeline_uuid="pipeline_001", + bot_uuid='bot_001', + 
pipeline_uuid='pipeline_001', pipeline_config={ - "ai": { - "runner": {"id": RUNNER_ID}, - "runner_config": { + 'ai': { + 'runner': {'id': RUNNER_ID}, + 'runner_config': { RUNNER_ID: { - "model": {"primary": "model_primary", "fallbacks": ["model_fallback"]}, - "knowledge-bases": ["kb_001"], - "timeout": 30, + 'model': {'primary': 'model_primary', 'fallbacks': ['model_fallback']}, + 'knowledge-bases': ['kb_001'], + 'timeout': 30, }, }, }, @@ -225,25 +225,25 @@ def make_query(): session=session, messages=[], user_message=provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_text("hello"), - provider_message.ContentElement.from_file_url("https://example.com/spec.txt", "spec.txt"), + provider_message.ContentElement.from_text('hello'), + provider_message.ContentElement.from_file_url('https://example.com/spec.txt', 'spec.txt'), ], ), variables={ - "_pipeline_bound_plugins": ["langbot/local-agent"], - "_fallback_model_uuids": ["model_fallback"], - "_pipeline_bound_skills": ["demo"], - "public_param": "visible", + '_pipeline_bound_plugins': ['langbot/local-agent'], + '_fallback_model_uuids': ['model_fallback'], + '_pipeline_bound_skills': ['demo'], + 'public_param': 'visible', }, - use_llm_model_uuid="model_primary", + use_llm_model_uuid='model_primary', use_funcs=[ resource_tool.LLMTool( - name="langbot/test-tool/search", - human_desc="Search", - description="Search test data", - parameters={"type": "object", "properties": {"q": {"type": "string"}}}, + name='langbot/test-tool/search', + human_desc='Search', + description='Search test data', + parameters={'type': 'object', 'properties': {'q': {'type': 'string'}}}, func=fake_func, ) ], @@ -253,57 +253,57 @@ def make_query(): def test_context_builder_includes_consumable_base64_attachments(): query = make_query() query.user_message = provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_text("see attached"), - 
provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), - provider_message.ContentElement.from_file_base64("data:text/plain;base64,aGVsbG8=", "hello.txt"), + provider_message.ContentElement.from_text('see attached'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), + provider_message.ContentElement.from_file_base64('data:text/plain;base64,aGVsbG8=', 'hello.txt'), ], ) query.message_chain = platform_message.MessageChain( - [platform_message.Image(base64="data:image/jpeg;base64,aGVsbG8=")] + [platform_message.Image(base64='data:image/jpeg;base64,aGVsbG8=')] ) input_data = QueryEntryAdapter._build_input(query) - assert input_data.contents[0].text == "see attached" - assert input_data.contents[1].image_base64 == "data:image/png;base64,aGVsbG8=" - assert input_data.contents[2].file_base64 == "data:text/plain;base64,aGVsbG8=" + assert input_data.contents[0].text == 'see attached' + assert input_data.contents[1].image_base64 == 'data:image/png;base64,aGVsbG8=' + assert input_data.contents[2].file_base64 == 'data:text/plain;base64,aGVsbG8=' artifact_types = [attachment.artifact_type for attachment in input_data.attachments] - assert artifact_types == ["image", "file", "image"] - assert input_data.attachments[1].name == "hello.txt" + assert artifact_types == ['image', 'file', 'image'] + assert input_data.attachments[1].name == 'hello.txt' def test_context_builder_deduplicates_message_chain_attachments(): query = make_query() query.user_message = None query.message_chain = platform_message.MessageChain( - [platform_message.Image(base64="data:image/jpeg;base64,aGVsbG8=")] + [platform_message.Image(base64='data:image/jpeg;base64,aGVsbG8=')] ) input_data = QueryEntryAdapter._build_input(query) - assert [content.type for content in input_data.contents] == ["image_base64"] + assert [content.type for content in input_data.contents] == ['image_base64'] assert len(input_data.attachments) == 1 - assert 
input_data.attachments[0].artifact_type == "image" - assert input_data.attachments[0].content == "data:image/jpeg;base64,aGVsbG8=" + assert input_data.attachments[0].artifact_type == 'image' + assert input_data.attachments[0].content == 'data:image/jpeg;base64,aGVsbG8=' def test_context_builder_preserves_same_source_duplicate_attachments(): query = make_query() query.user_message = provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), - provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), ], ) query.message_chain = platform_message.MessageChain([]) input_data = QueryEntryAdapter._build_input(query) - assert [attachment.artifact_type for attachment in input_data.attachments] == ["image", "image"] + assert [attachment.artifact_type for attachment in input_data.attachments] == ['image', 'image'] @pytest.fixture(autouse=True) @@ -314,10 +314,10 @@ async def clean_agent_state(): reset_persistent_state_store() registry = get_session_registry() for session in await registry.list_active_runs(): - await registry.unregister(session["run_id"]) + await registry.unregister(session['run_id']) # Create in-memory SQLite engine for tests - test_engine = create_async_engine("sqlite+aiosqlite:///:memory:") + test_engine = create_async_engine('sqlite+aiosqlite:///:memory:') # Create tables async with test_engine.begin() as conn: @@ -327,7 +327,7 @@ async def clean_agent_state(): # Cleanup for session in await registry.list_active_runs(): - await registry.unregister(session["run_id"]) + await registry.unregister(session['run_id']) reset_persistent_state_store() await test_engine.dispose() @@ -340,8 +340,8 @@ async def 
test_orchestrator_runs_fake_plugin_with_authorized_context(clean_agent plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "fake response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'fake response'}}, } ] ) @@ -352,47 +352,152 @@ async def test_orchestrator_runs_fake_plugin_with_authorized_context(clean_agent messages = [message async for message in orchestrator.run_from_query(query)] assert len(messages) == 1 - assert messages[0].content == "fake response" + assert messages[0].content == 'fake response' assert plugin_connector.calls == [ { - "plugin_author": "langbot", - "plugin_name": "local-agent", - "runner_name": "default", + 'plugin_author': 'langbot', + 'plugin_name': 'local-agent', + 'runner_name': 'default', } ] context = plugin_connector.contexts[0] - assert context["config"]["timeout"] == 30 - assert context["runtime"]["deadline_at"] is not None + assert context['config']['timeout'] == 30 + assert context['runtime']['deadline_at'] is not None # Protocol v1: params is in adapter.extra - assert context["adapter"]["extra"]["params"] == {"public_param": "visible"} - assert context["event"]["event_type"] == "message.received" + assert context['adapter']['extra']['params'] == {'public_param': 'visible'} + assert context['event']['event_type'] == 'message.received' # Note: source_event_type is in event.source_event_type, not event.data # (event.data contains the raw event payload, not metadata) - assert context["actor"]["actor_id"] == "user_001" - assert context["actor"]["actor_name"] == "Alice" - assert context["subject"]["subject_id"] == "msg_001" - assert context["input"]["attachments"] + assert context['actor']['actor_id'] == 'user_001' + assert context['actor']['actor_name'] == 'Alice' + assert context['subject']['subject_id'] == 'msg_001' + assert context['input']['attachments'] + assert 
context['context']['available_apis']['run_get'] is True + assert context['context']['available_apis']['run_list'] is True + assert context['context']['available_apis']['run_events_page'] is True + assert context['context']['available_apis']['run_cancel'] is True + assert context['context']['available_apis']['run_append_result'] is False + assert context['context']['available_apis']['run_finalize'] is False + assert context['context']['available_apis']['run_claim'] is False + assert context['context']['available_apis']['run_renew_claim'] is False + assert context['context']['available_apis']['run_release_claim'] is False + assert context['context']['available_apis']['runtime_register'] is False + assert context['context']['available_apis']['runtime_heartbeat'] is False + assert context['context']['available_apis']['runtime_list'] is False - resources = context["resources"] - assert {m["model_id"] for m in resources["models"]} == {"model_primary", "model_fallback"} - assert resources["tools"][0]["tool_name"] == "langbot/test-tool/search" - assert resources["knowledge_bases"][0]["kb_id"] == "kb_001" - assert resources["skills"] == [ + resources = context['resources'] + assert {m['model_id'] for m in resources['models']} == {'model_primary', 'model_fallback'} + assert resources['tools'][0]['tool_name'] == 'langbot/test-tool/search' + assert resources['knowledge_bases'][0]['kb_id'] == 'kb_001' + assert resources['skills'] == [ { - "skill_name": "demo", - "display_name": "Demo Skill", - "description": "Helps with demo tasks.", + 'skill_name': 'demo', + 'display_name': 'Demo Skill', + 'description': 'Helps with demo tasks.', } ] - assert resources["storage"]["plugin_storage"] is True + assert resources['storage']['plugin_storage'] is True session_during_run = plugin_connector.sessions_during_run[0] assert session_during_run is not None - assert session_during_run["plugin_identity"] == "langbot/local-agent" - assert 
session_during_run["authorization"]["authorized_ids"]["tool"] == {"langbot/test-tool/search"} - assert session_during_run["authorization"]["authorized_ids"]["skill"] == {"demo"} - assert await get_session_registry().get(context["run_id"]) is None + assert session_during_run['plugin_identity'] == 'langbot/local-agent' + assert session_during_run['authorization']['authorized_ids']['tool'] == {'langbot/test-tool/search'} + assert session_during_run['authorization']['authorized_ids']['skill'] == {'demo'} + assert await get_session_registry().get(context['run_id']) is None + + +@pytest.mark.asyncio +async def test_orchestrator_persists_run_ledger(clean_agent_state): + """AgentRunOrchestrator records Host-owned run and result events.""" + db_engine = clean_agent_state + descriptor = make_descriptor() + plugin_connector = FakePluginConnector( + results=[ + { + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'fake response'}}, + }, + { + 'type': 'run.completed', + 'data': {'finish_reason': 'stop'}, + 'usage': {'prompt_tokens': 2, 'completion_tokens': 3, 'total_tokens': 5}, + }, + ] + ) + ap = FakeApplication(plugin_connector, db_engine) + orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) + + messages = [message async for message in orchestrator.run_from_query(make_query())] + + assert len(messages) == 1 + run_id = plugin_connector.contexts[0]['run_id'] + store = RunLedgerStore(db_engine) + + run = await store.get_run(run_id) + assert run is not None + assert run['status'] == 'completed' + assert run['event_id'] == plugin_connector.contexts[0]['event']['event_id'] + assert run['runner_id'] == RUNNER_ID + assert run['usage'] == { + 'prompt_tokens': 2, + 'completion_tokens': 3, + 'total_tokens': 5, + } + + events, next_cursor, prev_cursor, has_more = await store.page_run_events( + run_id=run_id, + limit=10, + ) + assert [event['sequence'] for event in events] == [1, 2] + assert [event['type'] for event in events] == 
['message.completed', 'run.completed'] + assert next_cursor is None + assert prev_cursor == 1 + assert has_more is False + + +@pytest.mark.asyncio +async def test_orchestrator_stops_after_cancel_request(clean_agent_state): + """A persisted cancel request stops further synchronous runner consumption.""" + db_engine = clean_agent_state + descriptor = make_descriptor() + plugin_connector = FakePluginConnector( + results=[ + { + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'first'}}, + }, + { + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'second'}}, + }, + ] + ) + orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor)) + original_append_run_result = orchestrator.journal.append_run_result + cancel_requested = False + + async def append_and_cancel_once(*args, **kwargs): + nonlocal cancel_requested + event = await original_append_run_result(*args, **kwargs) + if not cancel_requested: + cancel_requested = True + await RunLedgerStore(db_engine).request_cancel( + run_id=kwargs['run_id'], + status_reason='user stopped', + ) + return event + + orchestrator.journal.append_run_result = append_and_cancel_once + + messages = [message async for message in orchestrator.run_from_query(make_query())] + + assert [message.content for message in messages] == ['first'] + run_id = plugin_connector.contexts[0]['run_id'] + run = await RunLedgerStore(db_engine).get_run(run_id) + assert run is not None + assert run['status'] == 'cancelled' + assert run['status_reason'] == 'user stopped' @pytest.mark.asyncio @@ -403,39 +508,39 @@ async def test_orchestrator_does_not_package_query_messages_into_context(clean_a plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "fake response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'fake response'}}, 
} ] ) ap = FakeApplication(plugin_connector, db_engine) orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() - query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["custom-option"] = 2 + query.pipeline_config['ai']['runner_config'][RUNNER_ID]['custom-option'] = 2 query.messages = [ - provider_message.Message(role="user", content="message 1"), - provider_message.Message(role="assistant", content="response 1"), - provider_message.Message(role="user", content="message 2"), - provider_message.Message(role="assistant", content="response 2"), - provider_message.Message(role="user", content="message 3"), - provider_message.Message(role="assistant", content="response 3"), + provider_message.Message(role='user', content='message 1'), + provider_message.Message(role='assistant', content='response 1'), + provider_message.Message(role='user', content='message 2'), + provider_message.Message(role='assistant', content='response 2'), + provider_message.Message(role='user', content='message 3'), + provider_message.Message(role='assistant', content='response 3'), ] messages = [message async for message in orchestrator.run_from_query(query)] assert len(messages) == 1 context = plugin_connector.contexts[0] - assert context["config"]["custom-option"] == 2 - assert "bootstrap" not in context - assert set(context["adapter"]) == {"extra"} - assert "context_packaging" not in context["runtime"]["metadata"] + assert context['config']['custom-option'] == 2 + assert 'bootstrap' not in context + assert set(context['adapter']) == {'extra'} + assert 'context_packaging' not in context['runtime']['metadata'] assert [message.content for message in query.messages] == [ - "message 1", - "response 1", - "message 2", - "response 2", - "message 3", - "response 3", + 'message 1', + 'response 1', + 'message 2', + 'response 2', + 'message 3', + 'response 3', ] @@ -446,16 +551,16 @@ async def test_orchestrator_streams_fake_plugin_deltas(clean_agent_state): descriptor = 
make_descriptor() plugin_connector = FakePluginConnector( results=[ - {"type": "message.delta", "data": {"chunk": {"role": "assistant", "content": "hel"}}}, - {"type": "message.delta", "data": {"chunk": {"role": "assistant", "content": "hello"}}}, - {"type": "run.completed", "data": {"finish_reason": "stop"}}, + {'type': 'message.delta', 'data': {'chunk': {'role': 'assistant', 'content': 'hel'}}}, + {'type': 'message.delta', 'data': {'chunk': {'role': 'assistant', 'content': 'hello'}}}, + {'type': 'run.completed', 'data': {'finish_reason': 'stop'}}, ] ) orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor)) chunks = [message async for message in orchestrator.run_from_query(make_query())] - assert [chunk.content for chunk in chunks] == ["hel", "hello"] + assert [chunk.content for chunk in chunks] == ['hel', 'hello'] @pytest.mark.asyncio @@ -468,10 +573,10 @@ async def test_orchestrator_persists_run_completed_message_transcript(clean_agen plugin_connector = FakePluginConnector( results=[ { - "type": "run.completed", - "data": { - "finish_reason": "stop", - "message": {"role": "assistant", "content": "final response"}, + 'type': 'run.completed', + 'data': { + 'finish_reason': 'stop', + 'message': {'role': 'assistant', 'content': 'final response'}, }, }, ] @@ -481,12 +586,12 @@ async def test_orchestrator_persists_run_completed_message_transcript(clean_agen messages = [message async for message in orchestrator.run_from_query(query)] - assert [message.content for message in messages] == ["final response"] + assert [message.content for message in messages] == ['final response'] transcript_store = TranscriptStore(db_engine) transcripts, _, _, _ = await transcript_store.page_transcript(query.session.using_conversation.uuid, limit=10) - assistant_items = [item for item in transcripts if item["role"] == "assistant"] + assistant_items = [item for item in transcripts if item['role'] == 'assistant'] assert 
len(assistant_items) == 1 - assert assistant_items[0]["content"] == "final response" + assert assistant_items[0]['content'] == 'final response' @pytest.mark.asyncio @@ -497,21 +602,21 @@ async def test_orchestrator_drops_duplicate_result_sequence(clean_agent_state): plugin_connector = FakePluginConnector( results=[ { - "type": "message.delta", - "sequence": 1, - "data": {"chunk": {"role": "assistant", "content": "first"}}, + 'type': 'message.delta', + 'sequence': 1, + 'data': {'chunk': {'role': 'assistant', 'content': 'first'}}, }, { - "type": "message.delta", - "sequence": 1, - "data": {"chunk": {"role": "assistant", "content": "duplicate"}}, + 'type': 'message.delta', + 'sequence': 1, + 'data': {'chunk': {'role': 'assistant', 'content': 'duplicate'}}, }, { - "type": "message.delta", - "sequence": 3, - "data": {"chunk": {"role": "assistant", "content": "after-gap"}}, + 'type': 'message.delta', + 'sequence': 3, + 'data': {'chunk': {'role': 'assistant', 'content': 'after-gap'}}, }, - {"type": "run.completed", "sequence": 4, "data": {"finish_reason": "stop"}}, + {'type': 'run.completed', 'sequence': 4, 'data': {'finish_reason': 'stop'}}, ] ) ap = FakeApplication(plugin_connector, db_engine) @@ -519,9 +624,9 @@ async def test_orchestrator_drops_duplicate_result_sequence(clean_agent_state): chunks = [message async for message in orchestrator.run_from_query(make_query())] - assert [chunk.content for chunk in chunks] == ["first", "after-gap"] - assert any("duplicate result sequence 1" in warning for warning in ap.logger.warnings) - assert any("result sequence gap or out-of-order" in warning for warning in ap.logger.warnings) + assert [chunk.content for chunk in chunks] == ['first', 'after-gap'] + assert any('duplicate result sequence 1' in warning for warning in ap.logger.warnings) + assert any('result sequence gap or out-of-order' in warning for warning in ap.logger.warnings) @pytest.mark.asyncio @@ -532,16 +637,16 @@ async def 
test_orchestrator_applies_state_updates_and_suppresses_protocol_event( plugin_connector = FakePluginConnector( results=[ { - "type": "state.updated", - "data": { - "scope": "conversation", - "key": "external.conversation_id", - "value": "external_conv_123", + 'type': 'state.updated', + 'data': { + 'scope': 'conversation', + 'key': 'external.conversation_id', + 'value': 'external_conv_123', }, }, { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "state saved"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'state saved'}}, }, ] ) @@ -550,7 +655,7 @@ async def test_orchestrator_applies_state_updates_and_suppresses_protocol_event( messages = [message async for message in orchestrator.run_from_query(query)] - assert [message.content for message in messages] == ["state saved"] + assert [message.content for message in messages] == ['state saved'] # State is persisted to the database via PersistentStateStore. @@ -562,8 +667,8 @@ async def test_orchestrator_unregisters_session_after_runner_failure(clean_agent plugin_connector = FakePluginConnector( results=[ { - "type": "run.failed", - "data": {"error": "boom", "code": "fake.error", "retryable": False}, + 'type': 'run.failed', + 'data': {'error': 'boom', 'code': 'fake.error', 'retryable': False}, } ] ) @@ -574,7 +679,7 @@ async def test_orchestrator_unregisters_session_after_runner_failure(clean_agent context = plugin_connector.contexts[0] assert plugin_connector.sessions_during_run[0] is not None - assert await get_session_registry().get(context["run_id"]) is None + assert await get_session_registry().get(context['run_id']) is None @pytest.mark.asyncio @@ -585,15 +690,15 @@ async def test_orchestrator_unregisters_session_after_event_log_failure(clean_ag plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "unused"}}, + 'type': 'message.completed', + 'data': 
{'message': {'role': 'assistant', 'content': 'unused'}}, } ] ) orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor)) - orchestrator.journal.write_event_log = AsyncMock(side_effect=RuntimeError("journal unavailable")) + orchestrator.journal.write_event_log = AsyncMock(side_effect=RuntimeError('journal unavailable')) - with pytest.raises(RuntimeError, match="journal unavailable"): + with pytest.raises(RuntimeError, match='journal unavailable'): [message async for message in orchestrator.run_from_query(make_query())] assert plugin_connector.contexts == [] @@ -608,21 +713,21 @@ async def test_orchestrator_enforces_total_runner_deadline(clean_agent_state): plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "too late"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'too late'}}, } ], delay=0.05, ) orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor)) query = make_query() - query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["timeout"] = 0.01 + query.pipeline_config['ai']['runner_config'][RUNNER_ID]['timeout'] = 0.01 with pytest.raises(RunnerExecutionError) as exc_info: [message async for message in orchestrator.run_from_query(query)] assert exc_info.value.retryable is True - assert "runner.timeout" in str(exc_info.value) + assert 'runner.timeout' in str(exc_info.value) assert await get_session_registry().list_active_runs() == [] @@ -637,8 +742,8 @@ class TestQueryEntrySessionQueryId: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -646,11 +751,11 @@ class TestQueryEntrySessionQueryId: orchestrator = AgentRunOrchestrator(ap, 
FakeRegistry(descriptor)) query = make_query() query.user_message = provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_text("hello"), - provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), - provider_message.ContentElement.from_file_base64("data:text/plain;base64,aGVsbG8=", "hello.txt"), + provider_message.ContentElement.from_text('hello'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), + provider_message.ContentElement.from_file_base64('data:text/plain;base64,aGVsbG8=', 'hello.txt'), ], ) @@ -660,12 +765,19 @@ class TestQueryEntrySessionQueryId: # Verify session during run had query_id session_during_run = plugin_connector.sessions_during_run[0] assert session_during_run is not None - assert session_during_run["query_id"] == query.query_id + assert session_during_run['query_id'] == query.query_id @pytest.mark.asyncio async def test_no_query_id_for_pure_event_first_flow(self, clean_agent_state): """Pure event-first flow has query_id=None in session.""" - from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy, DeliveryPolicy, ResourcePolicy + from langbot.pkg.agent.runner.host_models import ( + AgentEventEnvelope, + AgentBinding, + BindingScope, + StatePolicy, + DeliveryPolicy, + ResourcePolicy, + ) from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext @@ -674,8 +786,8 @@ class TestQueryEntrySessionQueryId: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -684,23 +796,23 @@ class TestQueryEntrySessionQueryId: # Create event and binding directly (not from Query) event = 
AgentEventEnvelope( - event_id="evt_001", - event_type="message.received", + event_id='evt_001', + event_type='message.received', event_time=1234567890, - source="test", - bot_id="bot_001", + source='test', + bot_id='bot_001', workspace_id=None, - conversation_id="conv_001", + conversation_id='conv_001', thread_id=None, actor=None, subject=None, - input=AgentInput(text="hello", contents=[], attachments=[]), - delivery=DeliveryContext(surface="test", supports_streaming=True), + input=AgentInput(text='hello', contents=[], attachments=[]), + delivery=DeliveryContext(surface='test', supports_streaming=True), ) binding = AgentBinding( - binding_id="binding_001", - scope=BindingScope(scope_type="agent", scope_id="pipeline_001"), - event_types=["message.received"], + binding_id='binding_001', + scope=BindingScope(scope_type='agent', scope_id='pipeline_001'), + event_types=['message.received'], runner_id=RUNNER_ID, runner_config={}, resource_policy=ResourcePolicy(), @@ -715,7 +827,7 @@ class TestQueryEntrySessionQueryId: # Verify session during run has query_id=None session_during_run = plugin_connector.sessions_during_run[0] assert session_during_run is not None - assert session_during_run["query_id"] is None + assert session_during_run['query_id'] is None class TestQueryEntryAdapterParams: @@ -731,8 +843,8 @@ class TestQueryEntryAdapterParams: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -742,18 +854,18 @@ class TestQueryEntryAdapterParams: # Add prompt to query query.prompt = provider_prompt.Prompt( - name="test_prompt", + name='test_prompt', messages=[ - provider_message.Message(role="system", content="You are a helpful assistant."), + provider_message.Message(role='system', content='You are a helpful assistant.'), ], ) _messages = [message async for 
message in orchestrator.run_from_query(query)] context = plugin_connector.contexts[0] - assert "prompt" not in context - assert "prompt" not in context["adapter"]["extra"] - assert context["context"]["available_apis"]["prompt_get"] is True + assert 'prompt' not in context + assert 'prompt' not in context['adapter']['extra'] + assert context['context']['available_apis']['prompt_get'] is True @pytest.mark.asyncio async def test_params_filtering_keeps_public_param(self, clean_agent_state): @@ -763,8 +875,8 @@ class TestQueryEntryAdapterParams: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -772,16 +884,16 @@ class TestQueryEntryAdapterParams: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.variables = { - "public_param": "visible", - "another_param": 123, + 'public_param': 'visible', + 'another_param': 123, } _messages = [message async for message in orchestrator.run_from_query(query)] context = plugin_connector.contexts[0] - assert context["adapter"]["extra"]["params"] == { - "public_param": "visible", - "another_param": 123, + assert context['adapter']['extra']['params'] == { + 'public_param': 'visible', + 'another_param': 123, } @pytest.mark.asyncio @@ -792,8 +904,8 @@ class TestQueryEntryAdapterParams: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -801,18 +913,18 @@ class TestQueryEntryAdapterParams: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.variables = { - "public_param": "visible", - "_internal_var": "should_be_filtered", - 
"_pipeline_bound_plugins": ["plugin1"], + 'public_param': 'visible', + '_internal_var': 'should_be_filtered', + '_pipeline_bound_plugins': ['plugin1'], } _messages = [message async for message in orchestrator.run_from_query(query)] context = plugin_connector.contexts[0] - params = context["adapter"]["extra"]["params"] - assert "public_param" in params - assert "_internal_var" not in params - assert "_pipeline_bound_plugins" not in params + params = context['adapter']['extra']['params'] + assert 'public_param' in params + assert '_internal_var' not in params + assert '_pipeline_bound_plugins' not in params @pytest.mark.asyncio async def test_params_filtering_removes_sensitive_patterns(self, clean_agent_state): @@ -822,8 +934,8 @@ class TestQueryEntryAdapterParams: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -831,22 +943,22 @@ class TestQueryEntryAdapterParams: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.variables = { - "public_param": "visible", - "api_token": "secret123", - "secret_key": "secret456", - "password": "secret789", - "credential": "secret000", + 'public_param': 'visible', + 'api_token': 'secret123', + 'secret_key': 'secret456', + 'password': 'secret789', + 'credential': 'secret000', } _messages = [message async for message in orchestrator.run_from_query(query)] context = plugin_connector.contexts[0] - params = context["adapter"]["extra"]["params"] - assert "public_param" in params - assert "api_token" not in params - assert "secret_key" not in params - assert "password" not in params - assert "credential" not in params + params = context['adapter']['extra']['params'] + assert 'public_param' in params + assert 'api_token' not in params + assert 'secret_key' not in params + assert 'password' 
not in params + assert 'credential' not in params @pytest.mark.asyncio async def test_params_filtering_removes_non_json_serializable(self, clean_agent_state): @@ -856,8 +968,8 @@ class TestQueryEntryAdapterParams: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'response'}}, } ] ) @@ -865,18 +977,18 @@ class TestQueryEntryAdapterParams: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.variables = { - "public_param": "visible", - "a_set": {1, 2, 3}, # set is not JSON-serializable - "a_lambda": lambda x: x, # function is not JSON-serializable + 'public_param': 'visible', + 'a_set': {1, 2, 3}, # set is not JSON-serializable + 'a_lambda': lambda x: x, # function is not JSON-serializable } _messages = [message async for message in orchestrator.run_from_query(query)] context = plugin_connector.contexts[0] - params = context["adapter"]["extra"]["params"] - assert "public_param" in params - assert "a_set" not in params - assert "a_lambda" not in params + params = context['adapter']['extra']['params'] + assert 'public_param' in params + assert 'a_set' not in params + assert 'a_lambda' not in params class TestQueryEntryAdapterHostCapabilities: @@ -892,16 +1004,16 @@ class TestQueryEntryAdapterHostCapabilities: plugin_connector = FakePluginConnector( results=[ { - "type": "state.updated", - "data": { - "scope": "conversation", - "key": "external.test_key", - "value": "test_value", + 'type': 'state.updated', + 'data': { + 'scope': 'conversation', + 'key': 'external.test_key', + 'value': 'test_value', }, }, { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "state saved"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'state saved'}}, }, ] ) @@ -909,29 +1021,30 @@ class 
TestQueryEntryAdapterHostCapabilities: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.user_message = provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_text("hello"), - provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), + provider_message.ContentElement.from_text('hello'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), ], ) messages = [message async for message in orchestrator.run_from_query(query)] assert len(messages) == 1 - assert messages[0].content == "state saved" + assert messages[0].content == 'state saved' # Verify state was written to PersistentStateStore persistent_store = get_persistent_state_store(db_engine) # Build snapshot to check if state was written # Note: We need to rebuild the event and binding to query the store from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter + event = QueryEntryAdapter.query_to_event(query) agent_config = QueryEntryAdapter.config_to_agent_config(query, RUNNER_ID) binding = AgentBindingResolver().resolve_one(event, [agent_config]) snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor) - assert snapshot["conversation"]["external.test_key"] == "test_value" + assert snapshot['conversation']['external.test_key'] == 'test_value' @pytest.mark.asyncio async def test_run_from_query_restores_activated_skills_from_state(self, clean_agent_state): @@ -947,8 +1060,8 @@ class TestQueryEntryAdapterHostCapabilities: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "restored"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'restored'}}, } ] ) @@ -964,9 +1077,9 @@ class TestQueryEntryAdapterHostCapabilities: event, binding, descriptor, - "conversation", + 'conversation', 
ACTIVATED_SKILL_NAMES_STATE_KEY, - ["demo"], + ['demo'], None, ) assert success is True @@ -975,7 +1088,7 @@ class TestQueryEntryAdapterHostCapabilities: messages = [message async for message in orchestrator.run_from_query(query)] assert len(messages) == 1 - assert query.variables[ACTIVATED_SKILLS_KEY]["demo"]["name"] == "demo" + assert query.variables[ACTIVATED_SKILLS_KEY]['demo']['name'] == 'demo' @pytest.mark.asyncio async def test_event_log_and_transcript_written(self, clean_agent_state): @@ -988,8 +1101,8 @@ class TestQueryEntryAdapterHostCapabilities: plugin_connector = FakePluginConnector( results=[ { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "assistant response"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'assistant response'}}, }, ] ) @@ -997,10 +1110,10 @@ class TestQueryEntryAdapterHostCapabilities: orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor)) query = make_query() query.user_message = provider_message.Message( - role="user", + role='user', content=[ - provider_message.ContentElement.from_text("hello"), - provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="), + provider_message.ContentElement.from_text('hello'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,aGVsbG8='), ], ) @@ -1016,10 +1129,10 @@ class TestQueryEntryAdapterHostCapabilities: ) assert len(event_logs) >= 1 # First event should be the incoming message.received - assert event_logs[0]["event_type"] == "message.received" - assert event_logs[0]["input_json"]["contents"][1]["image_base64"] is None - assert event_logs[0]["input_json"]["contents"][1]["content_redacted"] is True - assert "aGVsbG8=" not in str(event_logs[0]["input_json"]) + assert event_logs[0]['event_type'] == 'message.received' + assert event_logs[0]['input_json']['contents'][1]['image_base64'] is None + assert 
event_logs[0]['input_json']['contents'][1]['content_redacted'] is True + assert 'aGVsbG8=' not in str(event_logs[0]['input_json']) # Check Transcript has user and assistant messages transcript_store = TranscriptStore(db_engine) @@ -1030,13 +1143,13 @@ class TestQueryEntryAdapterHostCapabilities: ) assert len(transcripts) >= 2 # Find user and assistant messages - roles = [t["role"] for t in transcripts] - assert "user" in roles - assert "assistant" in roles - user_item = next(t for t in transcripts if t["role"] == "user") - assert user_item["content_json"]["content"][1]["image_base64"] is None - assert user_item["artifact_refs"][0]["content"] is None - assert "aGVsbG8=" not in str(user_item) + roles = [t['role'] for t in transcripts] + assert 'user' in roles + assert 'assistant' in roles + user_item = next(t for t in transcripts if t['role'] == 'user') + assert user_item['content_json']['content'][1]['image_base64'] is None + assert user_item['artifact_refs'][0]['content'] is None + assert 'aGVsbG8=' not in str(user_item) @pytest.mark.asyncio async def test_artifact_created_via_event_first_path(self, clean_agent_state): @@ -1047,24 +1160,24 @@ class TestQueryEntryAdapterHostCapabilities: db_engine = clean_agent_state descriptor = make_descriptor() - artifact_id = "artifact_001" - content = b"test artifact content" + artifact_id = 'artifact_001' + content = b'test artifact content' content_base64 = base64.b64encode(content).decode('utf-8') plugin_connector = FakePluginConnector( results=[ { - "type": "artifact.created", - "data": { - "artifact_id": artifact_id, - "artifact_type": "file", - "mime_type": "text/plain", - "name": "test.txt", - "content_base64": content_base64, + 'type': 'artifact.created', + 'data': { + 'artifact_id': artifact_id, + 'artifact_type': 'file', + 'mime_type': 'text/plain', + 'name': 'test.txt', + 'content_base64': content_base64, }, }, { - "type": "message.completed", - "data": {"message": {"role": "assistant", "content": "artifact 
created"}}, + 'type': 'message.completed', + 'data': {'message': {'role': 'assistant', 'content': 'artifact created'}}, }, ] ) @@ -1075,14 +1188,14 @@ class TestQueryEntryAdapterHostCapabilities: messages = [message async for message in orchestrator.run_from_query(query)] assert len(messages) == 1 - assert messages[0].content == "artifact created" + assert messages[0].content == 'artifact created' # Verify artifact was registered in ArtifactStore artifact_store = ArtifactStore(db_engine) artifact = await artifact_store.get_metadata(artifact_id) assert artifact is not None - assert artifact["artifact_type"] == "file" - assert artifact["name"] == "test.txt" + assert artifact['artifact_type'] == 'file' + assert artifact['name'] == 'test.txt' # Verify artifact.created event was written to EventLog event_log_store = EventLogStore(db_engine) @@ -1090,5 +1203,5 @@ class TestQueryEntryAdapterHostCapabilities: conversation_id=query.session.using_conversation.uuid, limit=10, ) - artifact_events = [e for e in event_logs if e["event_type"] == "artifact.created"] + artifact_events = [e for e in event_logs if e['event_type'] == 'artifact.created'] assert len(artifact_events) >= 1 diff --git a/tests/unit_tests/agent/test_run_ledger_api_auth.py b/tests/unit_tests/agent/test_run_ledger_api_auth.py new file mode 100644 index 00000000..ebe42767 --- /dev/null +++ b/tests/unit_tests/agent/test_run_ledger_api_auth.py @@ -0,0 +1,490 @@ +"""Tests for AgentRunner run ledger pull API authorization.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.ext.asyncio import create_async_engine + +from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore +from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry +from langbot.pkg.entity.persistence import agent_run as agent_run_model +from langbot.pkg.entity.persistence.base import Base +from langbot.pkg.plugin.handler import RuntimeConnectionHandler +from 
langbot_plugin.api.entities.builtin.agent_runner.run_ledger import ( + AgentRun, + AgentRunEvent, + RunEventPage, + RunPage, +) +from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction + +from .conftest import make_resources + + +class FakeConnection: + pass + + +class FakeApplication: + def __init__(self, db_engine): + self.logger = MagicMock() + self.persistence_mgr = MagicMock() + self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + +@pytest.fixture +def session_registry(monkeypatch): + registry = AgentRunSessionRegistry() + monkeypatch.setattr( + 'langbot.pkg.agent.runner.session_registry._global_registry', + registry, + ) + return registry + + +@pytest.fixture +async def db_engine(): + engine = create_async_engine('sqlite+aiosqlite:///:memory:') + assert agent_run_model.AgentRun.__tablename__ == 'agent_run' + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +def _handler(db_engine): + async def fake_disconnect(): + return True + + fake_app = FakeApplication(db_engine) + return RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + + +async def _register_session( + session_registry, + *, + run_id='run_1', + conversation_id='conv_1', + available_apis=None, +): + await session_registry.register( + run_id=run_id, + runner_id='plugin:test/runner/default', + query_id=None, + plugin_identity='test/runner', + resources=make_resources(), + conversation_id=conversation_id, + bot_id='bot_1', + workspace_id='workspace_1', + thread_id=None, + available_apis=available_apis or {}, + ) + + +async def _create_run( + db_engine, + *, + run_id='run_1', + conversation_id='conv_1', + bot_id='bot_1', + workspace_id='workspace_1', + thread_id=None, + plugin_identity='test/runner', + available_apis=None, +): + store = RunLedgerStore(db_engine) + await store.create_run( + run_id=run_id, + event_id='evt_1', + binding_id='binding_1', + 
runner_id='plugin:test/runner/default', + conversation_id=conversation_id, + bot_id=bot_id, + workspace_id=workspace_id, + thread_id=thread_id, + authorization={ + 'runner_id': 'plugin:test/runner/default', + 'binding_id': 'binding_1', + 'plugin_identity': plugin_identity, + 'resources': make_resources(), + 'available_apis': available_apis or {}, + 'conversation_id': conversation_id, + 'bot_id': bot_id, + 'workspace_id': workspace_id, + 'thread_id': thread_id, + 'state_policy': {'enable_state': True, 'state_scopes': ['conversation', 'actor']}, + 'state_context': {}, + }, + ) + await store.append_event( + run_id=run_id, + sequence=1, + event_type='message.completed', + data={'message': {'role': 'assistant', 'content': 'ok'}}, + ) + + +@pytest.mark.asyncio +async def test_run_get_returns_current_run(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_get': True}) + await _create_run(db_engine) + handler = _handler(db_engine) + run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value] + + result = await run_get( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code == 0 + run = AgentRun.model_validate(result.data) + assert run.run_id == 'run_1' + assert run.status == 'running' + + +@pytest.mark.asyncio +async def test_run_list_rejects_cross_conversation(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_list': True}) + handler = _handler(db_engine) + run_list = handler.actions[PluginToRuntimeAction.RUN_LIST.value] + + result = await run_list( + { + 'run_id': 'run_1', + 'conversation_id': 'conv_other', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code != 0 + assert 'not accessible' in result.message.lower() + + +@pytest.mark.asyncio +async def test_run_list_returns_scoped_runs(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_list': True}) + await 
_create_run(db_engine) + await _create_run(db_engine, run_id='run_other', conversation_id='conv_other') + handler = _handler(db_engine) + run_list = handler.actions[PluginToRuntimeAction.RUN_LIST.value] + + result = await run_list( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code == 0 + page = RunPage.model_validate(result.data) + assert [run.run_id for run in page.items] == ['run_1'] + + +@pytest.mark.asyncio +async def test_run_events_page_returns_events(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_events_page': True}) + await _create_run(db_engine) + handler = _handler(db_engine) + run_events_page = handler.actions[PluginToRuntimeAction.RUN_EVENTS_PAGE.value] + + result = await run_events_page( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code == 0 + page = RunEventPage.model_validate(result.data) + assert [item.sequence for item in page.items] == [1] + assert page.items[0].type == 'message.completed' + + +@pytest.mark.asyncio +async def test_run_get_uses_persistent_authorization_after_session_expired(db_engine): + await _create_run(db_engine, available_apis={'run_get': True}) + handler = _handler(db_engine) + run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value] + + result = await run_get( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code == 0 + run = AgentRun.model_validate(result.data) + assert run.run_id == 'run_1' + + +@pytest.mark.asyncio +async def test_persistent_run_get_rejects_cross_scope(db_engine): + await _create_run(db_engine, available_apis={'run_get': True}) + await _create_run( + db_engine, + run_id='run_other', + conversation_id='conv_other', + available_apis={'run_get': True}, + ) + handler = _handler(db_engine) + run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value] + + result = await run_get( + { + 'run_id': 'run_1', + 
'target_run_id': 'run_other', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code != 0 + assert 'not accessible' in result.message.lower() + + +@pytest.mark.asyncio +async def test_persistent_run_get_requires_capability(db_engine): + await _create_run(db_engine, available_apis={'run_get': False}) + handler = _handler(db_engine) + run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value] + + result = await run_get( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code != 0 + assert 'not authorized' in result.message.lower() + + +@pytest.mark.asyncio +async def test_persistent_authorization_does_not_reopen_artifact_api(db_engine): + await _create_run(db_engine, available_apis={'artifact_metadata': True}) + handler = _handler(db_engine) + artifact_metadata = handler.actions[PluginToRuntimeAction.ARTIFACT_METADATA.value] + + result = await artifact_metadata( + { + 'run_id': 'run_1', + 'artifact_id': 'artifact_1', + 'caller_plugin_identity': 'test/runner', + } + ) + + assert result.code != 0 + assert 'not found or expired' in result.message.lower() + + +@pytest.mark.asyncio +async def test_run_cancel_basic_path(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_cancel': True}) + await _create_run(db_engine) + handler = _handler(db_engine) + run_cancel = handler.actions[PluginToRuntimeAction.RUN_CANCEL.value] + + result = await run_cancel( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'reason': 'user requested', + } + ) + + assert result.code == 0 + run = AgentRun.model_validate(result.data) + assert run.run_id == 'run_1' + assert run.cancel_requested_at is not None + assert run.status_reason == 'user requested' + + +@pytest.mark.asyncio +async def test_run_append_result_basic_path(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_append_result': True}) + await _create_run(db_engine) + 
handler = _handler(db_engine) + run_append_result = handler.actions[PluginToRuntimeAction.RUN_APPEND_RESULT.value] + + result = await run_append_result( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'sequence': 2, + 'result': { + 'type': 'message.delta', + 'data': {'delta': 'hello'}, + 'usage': {'output_tokens': 1}, + }, + } + ) + + assert result.code == 0 + event = AgentRunEvent.model_validate(result.data) + assert event.run_id == 'run_1' + assert event.sequence == 2 + assert event.type == 'message.delta' + assert event.data == {'delta': 'hello'} + assert event.usage == {'output_tokens': 1} + + +@pytest.mark.asyncio +async def test_run_finalize_basic_path(session_registry, db_engine): + await _register_session(session_registry, available_apis={'run_finalize': True}) + await _create_run(db_engine) + handler = _handler(db_engine) + run_finalize = handler.actions[PluginToRuntimeAction.RUN_FINALIZE.value] + + result = await run_finalize( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'status': 'completed', + 'status_reason': 'done', + 'usage': {'total_tokens': 3}, + } + ) + + assert result.code == 0 + run = AgentRun.model_validate(result.data) + assert run.status == 'completed' + assert run.status_reason == 'done' + assert run.finished_at is not None + assert run.usage == {'total_tokens': 3} + + +@pytest.mark.asyncio +async def test_runtime_register_heartbeat_and_list_actions(session_registry, db_engine): + await _register_session( + session_registry, + available_apis={ + 'runtime_register': True, + 'runtime_heartbeat': True, + 'runtime_list': True, + }, + ) + handler = _handler(db_engine) + runtime_register = handler.actions[PluginToRuntimeAction.RUNTIME_REGISTER.value] + runtime_heartbeat = handler.actions[PluginToRuntimeAction.RUNTIME_HEARTBEAT.value] + runtime_list = handler.actions[PluginToRuntimeAction.RUNTIME_LIST.value] + + registered = await runtime_register( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 
'test/runner', + 'runtime_id': 'runtime_1', + 'display_name': 'Runtime 1', + 'capabilities': {'runner': True}, + 'labels': {'region': 'test'}, + 'metadata': {'slots': 2}, + } + ) + + assert registered.code == 0 + assert registered.data['runtime_id'] == 'runtime_1' + assert registered.data['capabilities'] == {'runner': True} + + heartbeat = await runtime_heartbeat( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'runtime_id': 'runtime_1', + 'capabilities': {'runner': True, 'stream': True}, + 'labels': {'region': 'test'}, + 'metadata': {'active_runs': 1}, + } + ) + + assert heartbeat.code == 0 + assert heartbeat.data['capabilities'] == {'runner': True, 'stream': True} + assert heartbeat.data['metadata'] == {'slots': 2, 'active_runs': 1} + + page = await runtime_list( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'statuses': ['online'], + 'labels': {'region': 'test'}, + } + ) + + assert page.code == 0 + assert [item['runtime_id'] for item in page.data['items']] == ['runtime_1'] + + +@pytest.mark.asyncio +async def test_run_claim_renew_and_release_actions(session_registry, db_engine): + await _register_session( + session_registry, + available_apis={ + 'run_claim': True, + 'run_renew_claim': True, + 'run_release_claim': True, + }, + ) + await RunLedgerStore(db_engine).create_run( + run_id='queued_run', + event_id='evt_queued', + binding_id='binding_1', + runner_id='plugin:test/runner/default', + conversation_id='conv_1', + bot_id='bot_1', + workspace_id='workspace_1', + status='queued', + queue_name='default', + priority=5, + ) + handler = _handler(db_engine) + run_claim = handler.actions[PluginToRuntimeAction.RUN_CLAIM.value] + run_renew_claim = handler.actions[PluginToRuntimeAction.RUN_RENEW_CLAIM.value] + run_release_claim = handler.actions[PluginToRuntimeAction.RUN_RELEASE_CLAIM.value] + + claimed = await run_claim( + { + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + 'runtime_id': 'runtime_1', + 
'queue_name': 'default', + 'lease_seconds': 30, + } + ) + + assert claimed.code == 0 + assert claimed.data['run_id'] == 'queued_run' + assert claimed.data['status'] == 'claimed' + assert claimed.data['claimed_by_runtime_id'] == 'runtime_1' + claim_token = claimed.data['claim_token'] + + renewed = await run_renew_claim( + { + 'run_id': 'run_1', + 'target_run_id': 'queued_run', + 'caller_plugin_identity': 'test/runner', + 'runtime_id': 'runtime_1', + 'claim_token': claim_token, + 'lease_seconds': 60, + } + ) + + assert renewed.code == 0 + assert renewed.data['claim_token'] == claim_token + + released = await run_release_claim( + { + 'run_id': 'run_1', + 'target_run_id': 'queued_run', + 'caller_plugin_identity': 'test/runner', + 'runtime_id': 'runtime_1', + 'claim_token': claim_token, + 'reason': 'done with lease', + } + ) + + assert released.code == 0 + assert released.data['status'] == 'queued' + assert released.data['claimed_by_runtime_id'] is None + assert released.data['claim_token'] is None diff --git a/tests/unit_tests/agent/test_run_ledger_store.py b/tests/unit_tests/agent/test_run_ledger_store.py new file mode 100644 index 00000000..84c35f3f --- /dev/null +++ b/tests/unit_tests/agent/test_run_ledger_store.py @@ -0,0 +1,167 @@ +"""Tests for RunLedgerStore host primitives.""" + +from __future__ import annotations + +import datetime + +import pytest +import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import sessionmaker + +from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore +from langbot.pkg.entity.persistence.agent_run import AgentRun +from langbot.pkg.entity.persistence.base import Base + + +UTC = datetime.timezone.utc + + +@pytest.fixture +async def db_engine(tmp_path): + db_path = tmp_path / 'run_ledger_store.db' + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False) + + async with engine.begin() as conn: + await 
conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + +@pytest.fixture +def store(db_engine): + return RunLedgerStore(db_engine) + + +@pytest.mark.asyncio +async def test_create_queued_run_claim_renew_release(store): + run = await store.create_run( + run_id='run-queued', + event_id='evt-1', + binding_id='binding-1', + runner_id='runner-a', + status='queued', + queue_name='default', + priority=10, + requested_runtime_id='runtime-a', + ) + + assert run['status'] == 'queued' + assert run['started_at'] is None + assert run['queue_name'] == 'default' + assert run['priority'] == 10 + assert run['requested_runtime_id'] == 'runtime-a' + + assert await store.claim_next_run(runtime_id='runtime-b', queue_name='default') is None + + claimed = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=30) + + assert claimed is not None + assert claimed['run_id'] == 'run-queued' + assert claimed['status'] == 'claimed' + assert claimed['claimed_by_runtime_id'] == 'runtime-a' + assert claimed['claim_token'] + assert claimed['dispatch_attempts'] == 1 + assert claimed['claim_lease_expires_at'] is not None + assert claimed['last_claimed_at'] is not None + + token = claimed['claim_token'] + assert await store.renew_claim(run_id='run-queued', claim_token='wrong-token') is None + + renewed = await store.renew_claim(run_id='run-queued', claim_token=token, lease_seconds=90) + + assert renewed is not None + assert renewed['claim_token'] == token + assert renewed['claim_lease_expires_at'] >= claimed['claim_lease_expires_at'] + + released = await store.release_claim( + run_id='run-queued', + claim_token=token, + status='queued', + status_reason='runtime released capacity', + ) + + assert released is not None + assert released['status'] == 'queued' + assert released['status_reason'] == 'runtime released capacity' + assert released['claimed_by_runtime_id'] is None + assert released['claim_token'] is None + assert 
released['claim_lease_expires_at'] is None + assert released['dispatch_attempts'] == 1 + + +@pytest.mark.asyncio +async def test_expired_claim_can_be_reclaimed(store, db_engine): + await store.create_run( + run_id='run-expired', + event_id='evt-2', + binding_id='binding-1', + runner_id='runner-a', + status='queued', + queue_name='default', + ) + first_claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60) + assert first_claim is not None + + session_factory = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with session_factory() as session: + await session.execute( + sqlalchemy.update(AgentRun) + .where(AgentRun.run_id == 'run-expired') + .values(claim_lease_expires_at=datetime.datetime.now(UTC) - datetime.timedelta(seconds=1)) + ) + await session.commit() + + reclaimed = await store.claim_next_run(runtime_id='runtime-b', queue_name='default', lease_seconds=60) + + assert reclaimed is not None + assert reclaimed['run_id'] == 'run-expired' + assert reclaimed['claimed_by_runtime_id'] == 'runtime-b' + assert reclaimed['claim_token'] != first_claim['claim_token'] + assert reclaimed['dispatch_attempts'] == 2 + + +@pytest.mark.asyncio +async def test_runtime_register_heartbeat_list_and_mark_stale(store): + registered = await store.register_runtime( + runtime_id='runtime-a', + display_name='Runtime A', + endpoint='http://runtime-a', + version='1.0.0', + capabilities={'stream': True}, + labels={'region': 'test'}, + metadata={'slot_count': 2}, + heartbeat_deadline_seconds=30, + ) + + assert registered['runtime_id'] == 'runtime-a' + assert registered['status'] == 'online' + assert registered['display_name'] == 'Runtime A' + assert registered['capabilities'] == {'stream': True} + assert registered['labels'] == {'region': 'test'} + assert registered['metadata'] == {'slot_count': 2} + assert registered['last_heartbeat_at'] is not None + assert registered['heartbeat_deadline_at'] is not None + + heartbeat = 
await store.heartbeat_runtime( + runtime_id='runtime-a', + metadata={'active_runs': 1}, + heartbeat_deadline_seconds=30, + ) + + assert heartbeat is not None + assert heartbeat['metadata'] == {'slot_count': 2, 'active_runs': 1} + + runtimes = await store.list_runtimes(statuses=['online']) + assert [runtime['runtime_id'] for runtime in runtimes] == ['runtime-a'] + + stale = await store.mark_stale_runtimes( + now=datetime.datetime.now(UTC) + datetime.timedelta(seconds=31), + ) + + assert [runtime['runtime_id'] for runtime in stale] == ['runtime-a'] + assert stale[0]['status'] == 'stale' + assert (await store.get_runtime('runtime-a'))['status'] == 'stale'