From ce007c49c8faee67460f433bf1858a43de2332e5 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Sat, 23 May 2026 21:45:11 +0800 Subject: [PATCH] feat(agent-runner): add persistent state APIs --- .../pkg/agent/runner/context_builder.py | 26 +- src/langbot/pkg/agent/runner/orchestrator.py | 101 +++- .../agent/runner/persistent_state_store.py | 522 +++++++++++++++++ .../pkg/agent/runner/session_registry.py | 21 +- .../entity/persistence/agent_runner_state.py | 89 +++ src/langbot/pkg/persistence/alembic/env.py | 1 + ..._add_agent_runner_state_table_for_host_.py | 68 +++ src/langbot/pkg/plugin/handler.py | 412 +++++++++++++- .../agent/test_context_builder_state.py | 361 ++++++++++++ .../agent/test_context_validation.py | 94 ++- tests/unit_tests/agent/test_state_api_auth.py | 538 ++++++++++++++++++ tests/unit_tests/agent/test_state_store.py | 236 +++++++- 12 files changed, 2407 insertions(+), 62 deletions(-) create mode 100644 src/langbot/pkg/agent/runner/persistent_state_store.py create mode 100644 src/langbot/pkg/entity/persistence/agent_runner_state.py create mode 100644 src/langbot/pkg/persistence/alembic/versions/6dfd3dd7f0c7_add_agent_runner_state_table_for_host_.py create mode 100644 tests/unit_tests/agent/test_context_builder_state.py create mode 100644 tests/unit_tests/agent/test_state_api_auth.py diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py index f2957427..551d207e 100644 --- a/src/langbot/pkg/agent/runner/context_builder.py +++ b/src/langbot/pkg/agent/runner/context_builder.py @@ -13,6 +13,7 @@ from .descriptor import AgentRunnerDescriptor from .config_migration import ConfigMigration from .context_packager import AgentContextPackager from .state_store import get_state_store +from .persistent_state_store import get_persistent_state_store from . 
import events as runner_events from .host_models import AgentEventEnvelope, AgentBinding from .pipeline_compat_adapter import PipelineCompatAdapter @@ -259,11 +260,13 @@ class AgentRunContextBuilder: # Build context access (no history inlined by default for Protocol v1) # Populate with actual values from stores - context_access = await self._build_context_access(event, descriptor) + context_access = await self._build_context_access(event, descriptor, binding) - # Build state snapshot from event context - state_store = get_state_store() - state: AgentRunState = state_store.build_snapshot_from_event(event, binding, descriptor) + # Build state snapshot from persistent state store (event-first Protocol v1) + persistent_state_store = get_persistent_state_store( + self.ap.persistence_mgr.get_db_engine() + ) + state: AgentRunState = await persistent_state_store.build_snapshot_from_event(event, binding, descriptor) # Build runtime context runtime: AgentRuntimeContext = { @@ -420,6 +423,7 @@ class AgentRunContextBuilder: } # Build context access (for legacy, minimal API availability) + # Legacy Query-based mode does NOT have persistent state API context_access = { 'conversation_id': conversation.get('conversation_id') if conversation else None, 'thread_id': None, @@ -441,7 +445,7 @@ class AgentRunContextBuilder: 'event_page': False, 'artifact_metadata': False, 'artifact_read': False, - 'state': True, + 'state': False, # Legacy Query mode does not have persistent state API 'storage': True, }, } @@ -869,12 +873,14 @@ class AgentRunContextBuilder: self, event: AgentEventEnvelope, descriptor: AgentRunnerDescriptor, + binding: AgentBinding | None = None, ) -> dict[str, typing.Any]: """Build ContextAccess with actual values from stores. 
Args: event: Event envelope descriptor: Runner descriptor + binding: Agent binding (required for state_policy in event-first mode) Returns: ContextAccess dict @@ -895,6 +901,14 @@ class AgentRunContextBuilder: artifact_metadata_enabled = 'metadata' in artifact_permissions artifact_read_enabled = 'read' in artifact_permissions + # Determine state API availability based on binding state_policy (event-first mode) + # For legacy Query-based mode, state is NOT available (no persistent state API) + state_enabled = False + if binding is not None: + state_policy = binding.state_policy + if state_policy.enable_state and state_policy.state_scopes: + state_enabled = True + # Get latest cursor and has_history_before if conversation exists latest_cursor = None has_history_before = False @@ -931,7 +945,7 @@ class AgentRunContextBuilder: 'event_page': event_page_enabled, 'artifact_metadata': artifact_metadata_enabled, 'artifact_read': artifact_read_enabled, - 'state': True, + 'state': state_enabled, 'storage': True, }, } diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index 160d4538..1e403199 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -17,6 +17,7 @@ from .context_builder import AgentRunContextBuilder, AgentRunContextPayload from .resource_builder import AgentResourceBuilder from .result_normalizer import AgentResultNormalizer from .state_store import get_state_store, RunnerScopedStateStore +from .persistent_state_store import get_persistent_state_store, PersistentStateStore from .session_registry import get_session_registry, AgentRunSessionRegistry from .config_migration import ConfigMigration from .host_models import AgentEventEnvelope, AgentBinding @@ -63,6 +64,7 @@ class AgentRunOrchestrator: # Cached singleton references (set in __init__) _session_registry: AgentRunSessionRegistry _state_store: RunnerScopedStateStore + _persistent_state_store: 
PersistentStateStore | None def __init__( self, @@ -77,6 +79,7 @@ class AgentRunOrchestrator: # Cache singleton references to avoid per-request getter calls self._session_registry = get_session_registry() self._state_store = get_state_store() + self._persistent_state_store = None # Lazy init on first use async def run( self, @@ -122,6 +125,9 @@ class AgentRunOrchestrator: resources=resources, ) + # Build state context for State API handlers + state_context = self._build_state_context(event, binding, descriptor) + # Register session for proxy action permission validation run_id = context['run_id'] await self._session_registry.register( @@ -132,6 +138,11 @@ class AgentRunOrchestrator: resources=resources, permissions=descriptor.permissions or {}, conversation_id=event.conversation_id, + state_policy={ + 'enable_state': binding.state_policy.enable_state, + 'state_scopes': list(binding.state_policy.state_scopes), + }, + state_context=state_context, ) # Write incoming event to EventLog @@ -170,7 +181,7 @@ class AgentRunOrchestrator: # Handle state.updated first - consume before normalizer if result_dict.get('type') == 'state.updated': - self._handle_state_updated_event(result_dict, event, binding, descriptor) + await self._handle_state_updated_event(result_dict, event, binding, descriptor) # Pass to normalizer for logging, but don't yield to pipeline await self.result_normalizer.normalize(result_dict, descriptor) continue @@ -555,7 +566,7 @@ class AgentRunOrchestrator: f'artifact.created failed to register artifact: {e}', ) - def _handle_state_updated_event( + async def _handle_state_updated_event( self, result_dict: dict[str, typing.Any], event: AgentEventEnvelope, @@ -564,6 +575,8 @@ class AgentRunOrchestrator: ) -> None: """Handle state.updated result in event-first mode. + Persists state to database via PersistentStateStore. 
+ Args: result_dict: Raw result dict with type='state.updated' event: Event envelope @@ -585,8 +598,14 @@ class AgentRunOrchestrator: ) return - # Apply update to state store using event context - success = self._state_store.apply_update_from_event( + # Lazy init persistent state store + if self._persistent_state_store is None: + self._persistent_state_store = get_persistent_state_store( + self.ap.persistence_mgr.get_db_engine() + ) + + # Apply update to persistent state store + success, error = await self._persistent_state_store.apply_update_from_event( event=event, binding=binding, descriptor=descriptor, @@ -600,7 +619,79 @@ class AgentRunOrchestrator: self.ap.logger.debug( f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}' ) - # Invalid scope or missing identity is already logged by apply_update_from_event + elif error: + self.ap.logger.warning( + f'Runner {descriptor.id} state.updated rejected: {error}' + ) + + def _build_state_context( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> dict[str, typing.Any]: + """Build state context for State API handlers. 
+ + Returns context with: + - scope_keys: Dict mapping scope name to scope_key + - binding_identity: Binding identity for state isolation + - Additional context fields for DB insert + """ + # Get binding identity + binding_identity = binding.binding_id + if not binding_identity: + scope = binding.scope + if scope.scope_type and scope.scope_id: + binding_identity = f"{scope.scope_type}:{scope.scope_id}" + else: + binding_identity = "unknown_binding" + + # Build scope keys for each scope + scope_keys: dict[str, str] = {} + + # Conversation scope + if event.conversation_id: + parts = [descriptor.id, binding_identity, event.conversation_id] + if event.thread_id: + parts.append(event.thread_id) + scope_keys['conversation'] = f'conversation:{":".join(parts)}' + + # Actor scope + if event.actor and event.actor.actor_id: + parts = [ + descriptor.id, + binding_identity, + event.actor.actor_type or 'user', + event.actor.actor_id, + ] + scope_keys['actor'] = f'actor:{":".join(parts)}' + + # Subject scope + if event.subject and event.subject.subject_id: + parts = [ + descriptor.id, + binding_identity, + event.subject.subject_type or 'unknown', + event.subject.subject_id, + ] + scope_keys['subject'] = f'subject:{":".join(parts)}' + + # Runner scope (always available) + parts = [descriptor.id, binding_identity] + scope_keys['runner'] = f'runner:{":".join(parts)}' + + return { + 'scope_keys': scope_keys, + 'binding_identity': binding_identity, + 'bot_id': event.bot_id, + 'workspace_id': event.workspace_id, + 'conversation_id': event.conversation_id, + 'thread_id': event.thread_id, + 'actor_type': event.actor.actor_type if event.actor else None, + 'actor_id': event.actor.actor_id if event.actor else None, + 'subject_type': event.subject.subject_type if event.subject else None, + 'subject_id': event.subject.subject_id if event.subject else None, + } async def _write_event_log( self, diff --git a/src/langbot/pkg/agent/runner/persistent_state_store.py 
b/src/langbot/pkg/agent/runner/persistent_state_store.py new file mode 100644 index 00000000..6bf8d850 --- /dev/null +++ b/src/langbot/pkg/agent/runner/persistent_state_store.py @@ -0,0 +1,522 @@ +"""Persistent state store for AgentRunner protocol state. + +This module provides a database-backed state store for event-first Protocol v1, +while preserving in-memory state store for legacy Query-based flow. +""" +from __future__ import annotations + +import typing +import json +import asyncio +import threading +from datetime import datetime + +import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy import select, delete, update + +from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query + +from .descriptor import AgentRunnerDescriptor +from .host_models import AgentEventEnvelope, AgentBinding +from ...entity.persistence.agent_runner_state import AgentRunnerState + + +# Valid state scopes for agent runner state updates. +VALID_STATE_SCOPES = ('conversation', 'actor', 'subject', 'runner') + +# Key mapping for backward compatibility +LEGACY_KEY_MAPPING = { + 'conversation_id': 'external.conversation_id', +} + +# Maximum value_json size (256KB) +MAX_VALUE_JSON_BYTES = 256 * 1024 + + +class PersistentStateStore: + """Database-backed state store for AgentRunner protocol state. + + IMPORTANT: This is HOST-OWNED protocol state, NOT plugin instance state. + + This store provides: + 1. Persistent storage across runs via database + 2. Scope isolation by runner_id + binding_identity + scope + 3. Policy enforcement (enable_state, state_scopes) + 4. 
JSON value validation and size limits + + Used by: + - Event-first Protocol v1 (async methods) + - State API handlers (get/set/delete/list) + """ + + def __init__(self, db_engine: AsyncEngine): + self._db_engine = db_engine + + # ========== Scope Key Building (shared with in-memory store) ========== + + def _get_binding_identity(self, binding: AgentBinding) -> str: + """Get stable binding identity for scope key.""" + if binding.binding_id: + return binding.binding_id + scope = binding.scope + if scope.scope_type and scope.scope_id: + return f"{scope.scope_type}:{scope.scope_id}" + return "unknown_binding" + + def _make_conversation_scope_key( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> str | None: + """Build conversation scope key from event and binding.""" + if not event.conversation_id: + return None + + binding_identity = self._get_binding_identity(binding) + parts = [ + descriptor.id, + binding_identity, + event.conversation_id, + ] + if event.thread_id: + parts.append(event.thread_id) + return f'conversation:{":".join(parts)}' + + def _make_actor_scope_key( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> str | None: + """Build actor scope key from event and binding.""" + if not event.actor or not event.actor.actor_id: + return None + + binding_identity = self._get_binding_identity(binding) + parts = [ + descriptor.id, + binding_identity, + event.actor.actor_type or 'user', + event.actor.actor_id, + ] + return f'actor:{":".join(parts)}' + + def _make_subject_scope_key( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> str | None: + """Build subject scope key from event and binding.""" + if not event.subject or not event.subject.subject_id: + return None + + binding_identity = self._get_binding_identity(binding) + parts = [ + descriptor.id, + binding_identity, + event.subject.subject_type 
or 'unknown', + event.subject.subject_id, + ] + return f'subject:{":".join(parts)}' + + def _make_runner_scope_key( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> str: + """Build runner scope key from event and binding.""" + binding_identity = self._get_binding_identity(binding) + parts = [ + descriptor.id, + binding_identity, + ] + return f'runner:{":".join(parts)}' + + def _get_scope_key( + self, + scope: str, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> str | None: + """Get scope key for given scope.""" + if scope == 'conversation': + return self._make_conversation_scope_key(event, binding, descriptor) + elif scope == 'actor': + return self._make_actor_scope_key(event, binding, descriptor) + elif scope == 'subject': + return self._make_subject_scope_key(event, binding, descriptor) + elif scope == 'runner': + return self._make_runner_scope_key(event, binding, descriptor) + return None + + def _check_scope_enabled(self, scope: str, binding: AgentBinding) -> bool: + """Check if scope is enabled by binding's state_policy.""" + state_policy = binding.state_policy + if not state_policy.enable_state: + return False + return scope in state_policy.state_scopes + + def _validate_json_value( + self, + value: typing.Any, + logger: typing.Any = None, + ) -> tuple[str | None, str | None]: + """Validate and serialize value to JSON. + + Returns: + Tuple of (json_string, error_message). If error_message is not None, + json_string will be None. 
+ """ + try: + json_str = json.dumps(value, ensure_ascii=False) + except (TypeError, ValueError) as e: + return None, f'Value is not JSON-serializable: {e}' + + # Check size limit + json_bytes = len(json_str.encode('utf-8')) + if json_bytes > MAX_VALUE_JSON_BYTES: + return None, f'Value size {json_bytes} bytes exceeds limit {MAX_VALUE_JSON_BYTES} bytes' + + return json_str, None + + # ========== Async DB Operations ========== + + async def build_snapshot_from_event( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + ) -> dict[str, dict[str, typing.Any]]: + """Build state snapshot for all scopes from event and binding. + + Reads from database, respects state_policy. + """ + state_policy = binding.state_policy + + # If state is disabled, return all empty scopes + if not state_policy.enable_state: + return { + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + } + + snapshot: dict[str, dict[str, typing.Any]] = { + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + } + + async with self._db_engine.connect() as conn: + for scope in VALID_STATE_SCOPES: + if not self._check_scope_enabled(scope, binding): + continue + + scope_key = self._get_scope_key(scope, event, binding, descriptor) + if not scope_key: + continue + + # Query all state entries for this scope_key + result = await conn.execute( + select(AgentRunnerState.state_key, AgentRunnerState.value_json) + .where(AgentRunnerState.scope_key == scope_key) + ) + rows = result.fetchall() + + for row in rows: + key = row.state_key + value_json = row.value_json + if value_json: + try: + snapshot[scope][key] = json.loads(value_json) + except json.JSONDecodeError: + pass # Skip invalid JSON + + # Seed external.conversation_id from event.conversation_id if not set + if self._check_scope_enabled('conversation', binding) and event.conversation_id: + if 'external.conversation_id' not in snapshot['conversation']: + 
snapshot['conversation']['external.conversation_id'] = event.conversation_id + + return snapshot + + async def apply_update_from_event( + self, + event: AgentEventEnvelope, + binding: AgentBinding, + descriptor: AgentRunnerDescriptor, + scope: str, + key: str, + value: typing.Any, + logger: typing.Any = None, + ) -> tuple[bool, str | None]: + """Apply a state update from event context. + + Returns: + Tuple of (success, error_message). If success is False, error_message + contains the reason. + """ + state_policy = binding.state_policy + + # Check if state is disabled + if not state_policy.enable_state: + return False, 'State is disabled by binding policy' + + # Validate scope + if scope not in VALID_STATE_SCOPES: + return False, f'Invalid scope: {scope}' + + # Check if scope is enabled + if not self._check_scope_enabled(scope, binding): + return False, f'Scope "{scope}" not enabled by binding policy' + + # Map legacy key names + if key in LEGACY_KEY_MAPPING: + key = LEGACY_KEY_MAPPING[key] + + # Get scope key + scope_key = self._get_scope_key(scope, event, binding, descriptor) + if not scope_key: + return False, f'Missing identity for scope "{scope}"' + + # Validate and serialize value + value_json, error = self._validate_json_value(value, logger) + if error: + return False, error + + # Build context fields + binding_identity = self._get_binding_identity(binding) + + async with self._db_engine.begin() as conn: + # Check if entry exists + result = await conn.execute( + select(AgentRunnerState.id) + .where(AgentRunnerState.scope_key == scope_key) + .where(AgentRunnerState.state_key == key) + ) + existing = result.first() + + now = datetime.utcnow() + + if existing: + # Update existing entry + await conn.execute( + update(AgentRunnerState) + .where(AgentRunnerState.id == existing.id) + .values( + value_json=value_json, + updated_at=now, + ) + ) + else: + # Insert new entry + await conn.execute( + sqlalchemy.insert(AgentRunnerState).values( + runner_id=descriptor.id, + 
binding_identity=binding_identity, + scope=scope, + scope_key=scope_key, + state_key=key, + value_json=value_json, + bot_id=event.bot_id, + workspace_id=event.workspace_id, + conversation_id=event.conversation_id, + thread_id=event.thread_id, + actor_type=event.actor.actor_type if event.actor else None, + actor_id=event.actor.actor_id if event.actor else None, + subject_type=event.subject.subject_type if event.subject else None, + subject_id=event.subject.subject_id if event.subject else None, + created_at=now, + updated_at=now, + ) + ) + + return True, None + + async def state_get( + self, + scope_key: str, + state_key: str, + ) -> typing.Any: + """Get a single state value by scope_key and state_key. + + Used by State API handlers. + """ + async with self._db_engine.connect() as conn: + result = await conn.execute( + select(AgentRunnerState.value_json) + .where(AgentRunnerState.scope_key == scope_key) + .where(AgentRunnerState.state_key == state_key) + ) + row = result.first() + + if not row or not row.value_json: + return None + + try: + return json.loads(row.value_json) + except json.JSONDecodeError: + return None + + async def state_set( + self, + scope_key: str, + state_key: str, + value: typing.Any, + runner_id: str, + binding_identity: str, + scope: str, + context: dict[str, typing.Any] | None = None, + logger: typing.Any = None, + ) -> tuple[bool, str | None]: + """Set a state value. + + Used by State API handlers. + Context contains optional fields like bot_id, conversation_id, etc. 
+ """ + # Validate and serialize value + value_json, error = self._validate_json_value(value, logger) + if error: + return False, error + + context = context or {} + + async with self._db_engine.begin() as conn: + # Check if entry exists + result = await conn.execute( + select(AgentRunnerState.id) + .where(AgentRunnerState.scope_key == scope_key) + .where(AgentRunnerState.state_key == state_key) + ) + existing = result.first() + + now = datetime.utcnow() + + if existing: + # Update existing entry + await conn.execute( + update(AgentRunnerState) + .where(AgentRunnerState.id == existing.id) + .values( + value_json=value_json, + updated_at=now, + ) + ) + else: + # Insert new entry + await conn.execute( + sqlalchemy.insert(AgentRunnerState).values( + runner_id=runner_id, + binding_identity=binding_identity, + scope=scope, + scope_key=scope_key, + state_key=state_key, + value_json=value_json, + bot_id=context.get('bot_id'), + workspace_id=context.get('workspace_id'), + conversation_id=context.get('conversation_id'), + thread_id=context.get('thread_id'), + actor_type=context.get('actor_type'), + actor_id=context.get('actor_id'), + subject_type=context.get('subject_type'), + subject_id=context.get('subject_id'), + created_at=now, + updated_at=now, + ) + ) + + return True, None + + async def state_delete( + self, + scope_key: str, + state_key: str, + ) -> bool: + """Delete a state value. + + Returns True if deleted, False if not found. + """ + async with self._db_engine.begin() as conn: + result = await conn.execute( + delete(AgentRunnerState) + .where(AgentRunnerState.scope_key == scope_key) + .where(AgentRunnerState.state_key == state_key) + .returning(AgentRunnerState.id) + ) + deleted = result.first() + return deleted is not None + + async def state_list( + self, + scope_key: str, + prefix: str | None = None, + limit: int = 100, + ) -> tuple[list[str], bool]: + """List state keys in a scope. + + Returns tuple of (keys, has_more). 
+ """ + # Enforce limit cap + limit = min(limit, 100) + + async with self._db_engine.connect() as conn: + query = ( + select(AgentRunnerState.state_key) + .where(AgentRunnerState.scope_key == scope_key) + .order_by(AgentRunnerState.state_key) + .limit(limit + 1) # Fetch one extra to check has_more + ) + + if prefix: + query = query.where( + AgentRunnerState.state_key.like(f'{prefix}%') + ) + + result = await conn.execute(query) + rows = result.fetchall() + + keys = [row.state_key for row in rows[:limit]] + has_more = len(rows) > limit + + return keys, has_more + + async def clear_all(self) -> None: + """Clear all state entries (for testing).""" + async with self._db_engine.begin() as conn: + await conn.execute(delete(AgentRunnerState)) + + +# Global singleton persistent state store +_persistent_state_store: PersistentStateStore | None = None +_persistent_state_store_lock = threading.Lock() + + +def get_persistent_state_store(db_engine: AsyncEngine | None = None) -> PersistentStateStore: + """Get the global persistent state store singleton. 
+ + Args: + db_engine: Database engine (required on first call) + + Returns: + PersistentStateStore singleton + """ + global _persistent_state_store + with _persistent_state_store_lock: + if _persistent_state_store is None: + if db_engine is None: + raise RuntimeError("db_engine required for first call to get_persistent_state_store") + _persistent_state_store = PersistentStateStore(db_engine) + return _persistent_state_store + + +def reset_persistent_state_store() -> None: + """Reset the global persistent state store (for testing).""" + global _persistent_state_store + with _persistent_state_store_lock: + _persistent_state_store = None diff --git a/src/langbot/pkg/agent/runner/session_registry.py b/src/langbot/pkg/agent/runner/session_registry.py index 6a0dca3e..1f6ae876 100644 --- a/src/langbot/pkg/agent/runner/session_registry.py +++ b/src/langbot/pkg/agent/runner/session_registry.py @@ -28,6 +28,8 @@ class AgentRunSession(typing.TypedDict): conversation_id: Conversation ID for history/event access resources: Authorized resources for this run (from AgentResources) permissions: Runner permissions from descriptor (artifacts, history, events, etc.) + state_policy: State policy from binding (enable_state, state_scopes) + state_context: Context for state API (scope_keys, binding_identity, etc.) 
status: Session status tracking _authorized_ids: Pre-computed authorized resource IDs for O(1) lookup """ @@ -38,6 +40,8 @@ class AgentRunSession(typing.TypedDict): conversation_id: str | None resources: AgentResources permissions: dict[str, list[str]] + state_policy: dict[str, typing.Any] # {enable_state: bool, state_scopes: list} + state_context: dict[str, typing.Any] # {scope_keys: dict, binding_identity: str, ...} status: AgentRunSessionStatus _authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup @@ -70,6 +74,8 @@ class AgentRunSessionRegistry: resources: AgentResources, conversation_id: str | None = None, permissions: dict[str, list[str]] | None = None, + state_policy: dict[str, typing.Any] | None = None, + state_context: dict[str, typing.Any] | None = None, ) -> None: """Register a new agent run session. @@ -81,12 +87,21 @@ class AgentRunSessionRegistry: resources: Authorized resources for this run conversation_id: Conversation ID for history/event access permissions: Runner permissions from descriptor (artifacts, history, events, etc.) + state_policy: State policy from binding (enable_state, state_scopes) + state_context: Context for state API (scope_keys, binding_identity, etc.) """ now = int(time.time()) # Normalize permissions to empty dict if None permissions = permissions or {} + # Normalize state_policy to defaults if None + if state_policy is None: + state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']} + + # Normalize state_context to empty dict if None + state_context = state_context or {} + # Pre-compute authorized resource IDs for O(1) lookup authorized_ids: dict[str, set[str]] = { 'model': {m.get('model_id') for m in resources.get('models', [])}, @@ -95,14 +110,18 @@ class AgentRunSessionRegistry: 'file': {f.get('file_id') for f in resources.get('files', [])}, } + # NOTE: state_policy and state_context are stored at session top-level, + # NOT in resources. 
Resources should only contain resource authorization info. session: AgentRunSession = { 'run_id': run_id, 'runner_id': runner_id, 'query_id': query_id, 'plugin_identity': plugin_identity, 'conversation_id': conversation_id, - 'resources': resources, + 'resources': resources, # Original AgentResources, no state metadata mixed in 'permissions': permissions, + 'state_policy': state_policy, + 'state_context': state_context, 'status': { 'started_at': now, 'last_activity_at': now, diff --git a/src/langbot/pkg/entity/persistence/agent_runner_state.py b/src/langbot/pkg/entity/persistence/agent_runner_state.py new file mode 100644 index 00000000..adc71ff8 --- /dev/null +++ b/src/langbot/pkg/entity/persistence/agent_runner_state.py @@ -0,0 +1,89 @@ +"""Agent runner state persistence entity for host-owned state.""" +from __future__ import annotations + +import sqlalchemy +import datetime + +from .base import Base + + +class AgentRunnerState(Base): + """AgentRunnerState stores host-owned state for AgentRunner protocol. + + State is: + - Host-owned: Managed by LangBot, not by plugin instances + - Scope-isolated: Separated by runner_id + binding_identity + scope + - Policy-enforced: Controlled by StatePolicy (enable_state, state_scopes) + + Scope key design: + - conversation: runner_id + binding_id + conversation_id [+ thread_id] + - actor: runner_id + binding_id + actor_type + actor_id + - subject: runner_id + binding_id + subject_type + subject_id + - runner: runner_id + binding_id + + This table persists state across runs, replacing the in-memory + RunnerScopedStateStore._store dict. 
+ """ + + __tablename__ = 'agent_runner_state' + + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True) + """Auto-increment ID for sequencing.""" + + # Identity + runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True) + """Runner descriptor ID (plugin:author/name/runner).""" + + binding_identity = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True) + """Binding identity for isolation (binding_id or scope_type:scope_id).""" + + scope = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True) + """State scope: 'conversation', 'actor', 'subject', or 'runner'.""" + + scope_key = sqlalchemy.Column(sqlalchemy.String(512), nullable=False, index=True) + """Full scope key for unique lookup (includes all identity parts).""" + + state_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + """State key within scope (should use namespace prefix like external.*).""" + + value_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True) + """State value as JSON string (size-limited by host).""" + + # Context fields for querying/filtering + bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) + """Bot UUID if applicable.""" + + workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Workspace ID for multi-tenant.""" + + conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) + """Conversation ID for conversation scope.""" + + thread_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Thread ID for thread-scoped conversation state.""" + + actor_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True) + """Actor type for actor scope.""" + + actor_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) + """Actor ID for actor scope.""" + + subject_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True) + """Subject type for subject scope.""" + + subject_id = 
sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Subject ID for subject scope.""" + + # Lifecycle + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow) + """When this state entry was created.""" + + updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow) + """When this state entry was last updated.""" + + # Unique constraint: scope_key + state_key + __table_args__ = ( + sqlalchemy.UniqueConstraint('scope_key', 'state_key', name='uq_agent_runner_state_scope_key_state_key'), + sqlalchemy.Index('ix_agent_runner_state_runner_binding', 'runner_id', 'binding_identity'), + sqlalchemy.Index('ix_agent_runner_state_scope_key_lookup', 'scope_key'), + ) diff --git a/src/langbot/pkg/persistence/alembic/env.py b/src/langbot/pkg/persistence/alembic/env.py index 6cb6d5b0..ec76d8e9 100644 --- a/src/langbot/pkg/persistence/alembic/env.py +++ b/src/langbot/pkg/persistence/alembic/env.py @@ -16,6 +16,7 @@ from langbot.pkg.entity.persistence.base import Base # Import all ORM models so they are registered with Base.metadata # This is required for autogenerate to detect model changes from langbot.pkg.entity.persistence import ( + agent_runner_state, apikey, artifact, bot, diff --git a/src/langbot/pkg/persistence/alembic/versions/6dfd3dd7f0c7_add_agent_runner_state_table_for_host_.py b/src/langbot/pkg/persistence/alembic/versions/6dfd3dd7f0c7_add_agent_runner_state_table_for_host_.py new file mode 100644 index 00000000..06551664 --- /dev/null +++ b/src/langbot/pkg/persistence/alembic/versions/6dfd3dd7f0c7_add_agent_runner_state_table_for_host_.py @@ -0,0 +1,68 @@ +# Alembic script.py.mako — template for auto-generated revisions +"""add agent_runner_state table for host-owned persistent state + +Revision ID: 6dfd3dd7f0c7 +Revises: a1b2c3d4e5f6 +Create Date: 2026-05-23 19:49:08.529110 +""" +from alembic import op +import sqlalchemy as sa + + +# 
# revision identifiers, used by Alembic.
revision = '6dfd3dd7f0c7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``agent_runner_state`` table and all of its indexes.

    Uniqueness of a state entry is enforced on ``(scope_key, state_key)``.
    The secondary indexes mirror the ones declared on the ORM model so that
    a migrated database matches the model metadata.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        'agent_runner_state',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('runner_id', sa.String(length=255), nullable=False),
        sa.Column('binding_identity', sa.String(length=255), nullable=False),
        sa.Column('scope', sa.String(length=50), nullable=False),
        sa.Column('scope_key', sa.String(length=512), nullable=False),
        sa.Column('state_key', sa.String(length=255), nullable=False),
        sa.Column('value_json', sa.Text(), nullable=True),
        sa.Column('bot_id', sa.String(length=255), nullable=True),
        sa.Column('workspace_id', sa.String(length=255), nullable=True),
        sa.Column('conversation_id', sa.String(length=255), nullable=True),
        sa.Column('thread_id', sa.String(length=255), nullable=True),
        sa.Column('actor_type', sa.String(length=50), nullable=True),
        sa.Column('actor_id', sa.String(length=255), nullable=True),
        sa.Column('subject_type', sa.String(length=50), nullable=True),
        sa.Column('subject_id', sa.String(length=255), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('scope_key', 'state_key', name='uq_agent_runner_state_scope_key_state_key'),
    )
    with op.batch_alter_table('agent_runner_state', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_agent_runner_state_actor_id'), ['actor_id'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_binding_identity'), ['binding_identity'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_bot_id'), ['bot_id'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_conversation_id'), ['conversation_id'], unique=False)
        batch_op.create_index('ix_agent_runner_state_runner_binding', ['runner_id', 'binding_identity'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_runner_id'), ['runner_id'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_scope'), ['scope'], unique=False)
        batch_op.create_index(batch_op.f('ix_agent_runner_state_scope_key'), ['scope_key'], unique=False)
        # The ORM model's __table_args__ also declares this explicit index on
        # scope_key; create it here so the migrated schema matches the model.
        # NOTE(review): it covers the same column as the index above, so it
        # could instead be removed from the model - confirm which is intended.
        batch_op.create_index('ix_agent_runner_state_scope_key_lookup', ['scope_key'], unique=False)

    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop every index created by :func:`upgrade`, then the table itself."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('agent_runner_state', schema=None) as batch_op:
        # Indexes are dropped in the reverse order of their creation.
        batch_op.drop_index('ix_agent_runner_state_scope_key_lookup')
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_scope_key'))
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_scope'))
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_runner_id'))
        batch_op.drop_index('ix_agent_runner_state_runner_binding')
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_conversation_id'))
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_bot_id'))
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_binding_identity'))
        batch_op.drop_index(batch_op.f('ix_agent_runner_state_actor_id'))

    op.drop_table('agent_runner_state')
    # ### end Alembic commands ###
handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: return handler.ActionResponse.error( message=f'Plugin identity mismatch for run_id {run_id}' ) @@ -1420,10 +1424,14 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Run session {run_id} not found or expired' ) - # Validate caller plugin identity - if caller_plugin_identity: - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + # Validate caller plugin identity (strict: required when session has plugin_identity) + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: + if not caller_plugin_identity: + return handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: return handler.ActionResponse.error( message=f'Plugin identity mismatch for run_id {run_id}' ) @@ -1483,10 +1491,14 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Run session {run_id} not found or expired' ) - # Validate caller plugin identity - if caller_plugin_identity: - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + # Validate caller plugin identity (strict: required when session has plugin_identity) + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: + if not caller_plugin_identity: + return handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: return handler.ActionResponse.error( message=f'Plugin identity mismatch for run_id {run_id}' ) @@ -1538,10 +1550,14 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Run session {run_id} not found or expired' ) - 
# Validate caller plugin identity - if caller_plugin_identity: - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + # Validate caller plugin identity (strict: required when session has plugin_identity) + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: + if not caller_plugin_identity: + return handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: return handler.ActionResponse.error( message=f'Plugin identity mismatch for run_id {run_id}' ) @@ -1610,10 +1626,14 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Run session {run_id} not found or expired' ) - # Validate caller plugin identity - if caller_plugin_identity: - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + # Validate caller plugin identity (strict: required when session has plugin_identity) + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: + if not caller_plugin_identity: + return handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: return handler.ActionResponse.error( message=f'Plugin identity mismatch for run_id {run_id}' ) @@ -1694,10 +1714,14 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Run session {run_id} not found or expired' ) - # Validate caller plugin identity - if caller_plugin_identity: - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + # Validate caller plugin identity (strict: required when session has plugin_identity) + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: 
# ================= State APIs (run-scoped, policy-enforced) =================

async def _resolve_state_scope(
    data: dict[str, Any],
    *,
    require_key: bool = True,
) -> 'tuple[dict[str, Any] | None, handler.ActionResponse | None]':
    """Authorize a state request and resolve the scope key it may touch.

    Checks run in this order: required fields, run session lookup, caller
    plugin identity, state policy, enabled scope, scope key availability.

    Args:
        data: Raw action payload sent by the plugin runtime.
        require_key: Whether a non-empty ``key`` field is mandatory
            (false for listing, which addresses a whole scope).

    Returns:
        ``(resolved, None)`` on success, where ``resolved`` holds
        ``session``, ``scope``, ``scope_key`` and ``state_context``;
        ``(None, error_response)`` as soon as any check fails.
    """
    run_id = data.get('run_id')
    scope = data.get('scope')

    if not run_id:
        return None, handler.ActionResponse.error(message='run_id is required')

    if not scope:
        return None, handler.ActionResponse.error(message='scope is required')

    if require_key and not data.get('key'):
        return None, handler.ActionResponse.error(message='key is required')

    # Validate run session
    session = await get_session_registry().get(run_id)
    if not session:
        return None, handler.ActionResponse.error(
            message=f'Run session {run_id} not found or expired'
        )

    # Validate caller plugin identity (strict: required when session has plugin_identity)
    session_plugin_identity = session.get('plugin_identity')
    if session_plugin_identity:
        caller_plugin_identity = data.get('caller_plugin_identity')
        if not caller_plugin_identity:
            return None, handler.ActionResponse.error(
                message=f'caller_plugin_identity is required for run_id {run_id}'
            )
        if caller_plugin_identity != session_plugin_identity:
            return None, handler.ActionResponse.error(
                message=f'Plugin identity mismatch for run_id {run_id}'
            )

    # State policy lives in the session's state_policy field (not in resources).
    # NOTE(review): a session without a policy falls back to an *enabled*
    # default here - confirm this permissive fallback is intended.
    default_scopes = ['conversation', 'actor']
    state_policy = session.get('state_policy') or {
        'enable_state': True,
        'state_scopes': default_scopes,
    }

    if not state_policy.get('enable_state', True):
        return None, handler.ActionResponse.error(
            message='State access is disabled by binding policy'
        )

    if scope not in state_policy.get('state_scopes', default_scopes):
        return None, handler.ActionResponse.error(
            message=f'Scope "{scope}" is not enabled by binding policy'
        )

    # Scope keys are precomputed in the session's state_context field; the
    # plugin never supplies them. `or {}` tolerates a field stored as None.
    state_context = session.get('state_context') or {}
    scope_key = (state_context.get('scope_keys') or {}).get(scope)
    if not scope_key:
        return None, handler.ActionResponse.error(
            message=f'Scope key not available for scope "{scope}"'
        )

    resolved = {
        'session': session,
        'scope': scope,
        'scope_key': scope_key,
        'state_context': state_context,
    }
    return resolved, None

def _state_store():
    """Return the persistent state store bound to the host DB engine."""
    from ..agent.runner.persistent_state_store import get_persistent_state_store
    return get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())

@self.action(PluginToRuntimeAction.STATE_GET)
async def state_get(data: dict[str, Any]) -> handler.ActionResponse:
    """Get a state value from host-owned state store.

    Requires run_id authorization and scope enabled by state_policy.
    Responds with ``{'value': ...}`` on success.
    """
    resolved, denied = await _resolve_state_scope(data)
    if denied is not None:
        return denied

    store = _state_store()

    try:
        value = await store.state_get(resolved['scope_key'], data.get('key'))
        return handler.ActionResponse.success(data={'value': value})
    except Exception as e:
        self.ap.logger.error(f'STATE_GET error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'State get error: {e}')

@self.action(PluginToRuntimeAction.STATE_SET)
async def state_set(data: dict[str, Any]) -> handler.ActionResponse:
    """Set a state value in host-owned state store.

    Requires run_id authorization and scope enabled by state_policy.
    Value must be JSON-serializable and size-limited.
    """
    resolved, denied = await _resolve_state_scope(data)
    if denied is not None:
        return denied

    state_context = resolved['state_context']

    # Additional identity columns recorded alongside the value.
    runner_id = resolved['session'].get('runner_id', '')
    binding_identity = state_context.get('binding_identity', 'unknown')

    store = _state_store()

    try:
        success, error = await store.state_set(
            scope_key=resolved['scope_key'],
            state_key=data.get('key'),
            value=data.get('value'),
            runner_id=runner_id,
            binding_identity=binding_identity,
            scope=resolved['scope'],
            context=state_context,
            logger=self.ap.logger,
        )

        if not success:
            return handler.ActionResponse.error(message=error or 'Failed to set state')

        return handler.ActionResponse.success(data={'success': True})
    except Exception as e:
        self.ap.logger.error(f'STATE_SET error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'State set error: {e}')

@self.action(PluginToRuntimeAction.STATE_DELETE)
async def state_delete(data: dict[str, Any]) -> handler.ActionResponse:
    """Delete a state value from host-owned state store.

    Requires run_id authorization and scope enabled by state_policy.
    Responds with ``{'success': <store result>}``.
    """
    resolved, denied = await _resolve_state_scope(data)
    if denied is not None:
        return denied

    store = _state_store()

    try:
        deleted = await store.state_delete(resolved['scope_key'], data.get('key'))
        return handler.ActionResponse.success(data={'success': deleted})
    except Exception as e:
        self.ap.logger.error(f'STATE_DELETE error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'State delete error: {e}')

@self.action(PluginToRuntimeAction.STATE_LIST)
async def state_list(data: dict[str, Any]) -> handler.ActionResponse:
    """List state keys in a scope.

    Requires run_id authorization and scope enabled by state_policy.
    ``limit`` defaults to 100 and is capped at 100; ``prefix`` is optional.
    """
    resolved, denied = await _resolve_state_scope(data, require_key=False)
    if denied is not None:
        return denied

    prefix = data.get('prefix')
    limit = data.get('limit', 100)

    # bool is a subclass of int, so reject it explicitly; otherwise a
    # payload of `true` would silently become a limit of 1.
    if isinstance(limit, bool) or not isinstance(limit, int) or limit <= 0:
        limit = 100
    limit = min(limit, 100)  # Cap at 100

    store = _state_store()

    try:
        keys, has_more = await store.state_list(resolved['scope_key'], prefix, limit)
        return handler.ActionResponse.success(data={
            'keys': keys,
            'has_more': has_more,
        })
    except Exception as e:
        self.ap.logger.error(f'STATE_LIST error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'State list error: {e}')
from __future__ import annotations

import pytest
from unittest.mock import MagicMock, AsyncMock, patch

from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext


def _make_event(event_id, conversation_id):
    """Build a 'message.received' envelope; only id and conversation vary."""
    return AgentEventEnvelope(
        event_id=event_id,
        event_type='message.received',
        event_time=1234567890,
        source='test',
        bot_id='bot_001',
        workspace_id='ws_001',
        conversation_id=conversation_id,
        thread_id=None,
        actor=ActorContext(actor_type='user', actor_id='user_001'),
        subject=None,
        input=AgentInput(text='hello', contents=[], attachments=[]),
        delivery=DeliveryContext(surface='test', supports_streaming=True),
    )


def _make_binding(enable_state, state_scopes, binding_id='binding_001',
                  scope_type='pipeline', scope_id='conv_001'):
    """Build a binding for the default test runner with the given state policy."""
    policy = StatePolicy(enable_state=enable_state, state_scopes=state_scopes)
    return AgentBinding(
        binding_id=binding_id,
        runner_id='plugin:test/runner/default',
        scope=BindingScope(scope_type=scope_type, scope_id=scope_id),
        state_policy=policy,
    )


class MockApplication:
    """Application stand-in exposing only ``logger`` and ``persistence_mgr``."""
    def __init__(self):
        self.logger = MagicMock()
        self.persistence_mgr = MagicMock()
        self.persistence_mgr.get_db_engine = MagicMock()


class TestContextAccessStateDetermination:
    """Checks available_apis['state'] as produced by real _build_context_access calls."""

    @pytest.fixture
    def mock_app(self):
        """Application double handed to the builder."""
        return MockApplication()

    @pytest.fixture
    def mock_event(self):
        """Event envelope that carries a conversation id."""
        return _make_event('evt_001', 'conv_001')

    @pytest.fixture
    def mock_descriptor(self):
        """Runner descriptor double with no declared permissions."""
        runner = MagicMock()
        runner.id = 'plugin:test/runner/default'
        runner.protocol_version = '1.0'
        runner.permissions = {}
        return runner

    @pytest.mark.asyncio
    async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
        """Enabled policy plus at least one scope exposes the state API."""
        binding = _make_binding(True, ['conversation', 'actor'])

        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            mock_event, mock_descriptor, binding
        )

        assert access['available_apis']['state'] is True

    @pytest.mark.asyncio
    async def test_enable_state_false_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """A disabled policy hides the state API."""
        binding = _make_binding(False, [])

        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            mock_event, mock_descriptor, binding
        )

        assert access['available_apis']['state'] is False

    @pytest.mark.asyncio
    async def test_enable_state_true_empty_scopes_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """An enabled policy with no scopes still hides the state API."""
        binding = _make_binding(True, [])  # nothing to address, so not available

        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            mock_event, mock_descriptor, binding
        )

        assert access['available_apis']['state'] is False

    @pytest.mark.asyncio
    async def test_no_binding_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """Without a binding there is no state policy, so state is off."""
        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            mock_event, mock_descriptor, binding=None
        )

        assert access['available_apis']['state'] is False

    @pytest.mark.asyncio
    async def test_runner_scope_available_without_conversation(self, mock_app, mock_descriptor):
        """Runner scope keeps the state API available when conversation_id is None."""
        event_without_conversation = _make_event('evt_002', None)
        binding = _make_binding(
            True,
            ['runner'],  # runner scope does not depend on a conversation
            binding_id='binding_002',
            scope_type='workspace',
            scope_id='ws_001',
        )

        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            event_without_conversation, mock_descriptor, binding
        )

        assert access['available_apis']['state'] is True

    @pytest.mark.asyncio
    async def test_multiple_scopes_all_available(self, mock_app, mock_event, mock_descriptor):
        """All four scopes enabled together expose the state API."""
        binding = _make_binding(
            True,
            ['conversation', 'actor', 'subject', 'runner'],
            binding_id='binding_003',
        )

        access = await AgentRunContextBuilder(mock_app)._build_context_access(
            mock_event, mock_descriptor, binding
        )

        assert access['available_apis']['state'] is True


class TestStatePolicyFromBinding:
    """Field-level checks on StatePolicy itself."""

    def test_state_policy_structure(self):
        """An enabled policy keeps its flag and every listed scope."""
        all_scopes = ['conversation', 'actor', 'subject', 'runner']
        policy = StatePolicy(enable_state=True, state_scopes=all_scopes)

        assert policy.enable_state is True
        assert len(policy.state_scopes) == 4
        assert 'conversation' in policy.state_scopes

    def test_state_policy_disabled(self):
        """A disabled policy keeps its flag and an empty scope list."""
        policy = StatePolicy(enable_state=False, state_scopes=[])

        assert policy.enable_state is False
        assert len(policy.state_scopes) == 0


class TestBindingWithStatePolicy:
    """Checks that AgentBinding carries the policy it was built with."""

    def test_binding_contains_state_policy(self):
        """The state_policy passed to the binding is readable back from it."""
        binding = _make_binding(True, ['conversation'])

        assert binding.state_policy is not None
        assert binding.state_policy.enable_state is True


class TestContextAccessOtherAPIs:
    """Tests for other available_apis fields based on permissions."""

@pytest.fixture + def mock_app(self): + """Create mock application.""" + return MockApplication() + + @pytest.mark.asyncio + async def test_history_apis_based_on_permissions(self, mock_app): + """History APIs availability based on runner permissions.""" + mock_event = MagicMock() + mock_event.conversation_id = 'conv_001' + mock_event.thread_id = None + + mock_descriptor = MagicMock() + mock_descriptor.permissions = { + 'history': ['page', 'search'], + } + + binding = AgentBinding( + binding_id='binding_001', + runner_id='plugin:test/runner/default', + scope=BindingScope(scope_type='pipeline', scope_id='conv_001'), + state_policy=StatePolicy(enable_state=False, state_scopes=[]), + ) + + builder = AgentRunContextBuilder(mock_app) + + # Real call + context_access = await builder._build_context_access(mock_event, mock_descriptor, binding) + + # History APIs enabled based on permissions + assert context_access['available_apis']['history_page'] is True + assert context_access['available_apis']['history_search'] is True + + @pytest.mark.asyncio + async def test_event_apis_based_on_permissions(self, mock_app): + """Event APIs availability based on runner permissions.""" + mock_event = MagicMock() + mock_event.conversation_id = 'conv_001' + mock_event.thread_id = None + + mock_descriptor = MagicMock() + mock_descriptor.permissions = { + 'events': ['get', 'page'], + } + + binding = AgentBinding( + binding_id='binding_001', + runner_id='plugin:test/runner/default', + scope=BindingScope(scope_type='pipeline', scope_id='conv_001'), + state_policy=StatePolicy(enable_state=False, state_scopes=[]), + ) + + builder = AgentRunContextBuilder(mock_app) + + # Real call + context_access = await builder._build_context_access(mock_event, mock_descriptor, binding) + + # Event APIs enabled based on permissions + assert context_access['available_apis']['event_get'] is True + assert context_access['available_apis']['event_page'] is True + + @pytest.mark.asyncio + async def 
test_artifact_apis_based_on_permissions(self, mock_app): + """Artifact APIs availability based on runner permissions.""" + mock_event = MagicMock() + mock_event.conversation_id = 'conv_001' + mock_event.thread_id = None + + mock_descriptor = MagicMock() + mock_descriptor.permissions = { + 'artifacts': ['metadata', 'read'], + } + + binding = AgentBinding( + binding_id='binding_001', + runner_id='plugin:test/runner/default', + scope=BindingScope(scope_type='pipeline', scope_id='conv_001'), + state_policy=StatePolicy(enable_state=False, state_scopes=[]), + ) + + builder = AgentRunContextBuilder(mock_app) + + # Real call + context_access = await builder._build_context_access(mock_event, mock_descriptor, binding) + + # Artifact APIs enabled based on permissions + assert context_access['available_apis']['artifact_metadata'] is True + assert context_access['available_apis']['artifact_read'] is True + + @pytest.mark.asyncio + async def test_no_permissions_all_apis_disabled(self, mock_app): + """All pull APIs disabled when permissions are empty.""" + mock_event = MagicMock() + mock_event.conversation_id = 'conv_001' + mock_event.thread_id = None + + mock_descriptor = MagicMock() + mock_descriptor.permissions = {} # No permissions + + binding = AgentBinding( + binding_id='binding_001', + runner_id='plugin:test/runner/default', + scope=BindingScope(scope_type='pipeline', scope_id='conv_001'), + state_policy=StatePolicy(enable_state=False, state_scopes=[]), + ) + + builder = AgentRunContextBuilder(mock_app) + + # Real call + context_access = await builder._build_context_access(mock_event, mock_descriptor, binding) + + # All pull APIs should be disabled + assert context_access['available_apis']['history_page'] is False + assert context_access['available_apis']['history_search'] is False + assert context_access['available_apis']['event_get'] is False + assert context_access['available_apis']['event_page'] is False + assert context_access['available_apis']['artifact_metadata'] is 
False + assert context_access['available_apis']['artifact_read'] is False + assert context_access['available_apis']['state'] is False diff --git a/tests/unit_tests/agent/test_context_validation.py b/tests/unit_tests/agent/test_context_validation.py index ffb641ff..84fc354e 100644 --- a/tests/unit_tests/agent/test_context_validation.py +++ b/tests/unit_tests/agent/test_context_validation.py @@ -113,13 +113,24 @@ class TestContextValidation: resources = self._make_resources() descriptor = self._make_descriptor() - # Build context - context_dict = await builder.build_context_from_event( - event=event, - binding=binding, - descriptor=descriptor, - resources=resources, - ) + # Mock persistent state store to return empty state snapshot + with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store: + mock_store = AsyncMock() + mock_store.build_snapshot_from_event = AsyncMock(return_value={ + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + }) + mock_get_store.return_value = mock_store + + # Build context + context_dict = await builder.build_context_from_event( + event=event, + binding=binding, + descriptor=descriptor, + resources=resources, + ) # Validate it can be parsed by SDK AgentRunContext # This will raise ValidationError if invalid @@ -162,12 +173,23 @@ class TestContextValidation: resources = self._make_resources() descriptor = self._make_descriptor() - context_dict = await builder.build_context_from_event( - event=event, - binding=binding, - descriptor=descriptor, - resources=resources, - ) + # Mock persistent state store to return empty state snapshot + with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store: + mock_store = AsyncMock() + mock_store.build_snapshot_from_event = AsyncMock(return_value={ + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + }) + mock_get_store.return_value = mock_store + + context_dict = await 
builder.build_context_from_event( + event=event, + binding=binding, + descriptor=descriptor, + resources=resources, + ) # Protocol v1 does NOT have these as core fields assert 'messages' not in context_dict, "messages should not be top-level in Protocol v1" @@ -192,12 +214,23 @@ class TestContextValidation: resources = self._make_resources() descriptor = self._make_descriptor() - context_dict = await builder.build_context_from_event( - event=event, - binding=binding, - descriptor=descriptor, - resources=resources, - ) + # Mock persistent state store to return empty state snapshot + with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store: + mock_store = AsyncMock() + mock_store.build_snapshot_from_event = AsyncMock(return_value={ + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + }) + mock_get_store.return_value = mock_store + + context_dict = await builder.build_context_from_event( + event=event, + binding=binding, + descriptor=descriptor, + resources=resources, + ) # event is REQUIRED in Protocol v1 assert context_dict.get('event') is not None, "event is REQUIRED for Protocol v1" @@ -217,12 +250,23 @@ class TestContextValidation: resources = self._make_resources() descriptor = self._make_descriptor() - context_dict = await builder.build_context_from_event( - event=event, - binding=binding, - descriptor=descriptor, - resources=resources, - ) + # Mock persistent state store to return empty state snapshot + with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store: + mock_store = AsyncMock() + mock_store.build_snapshot_from_event = AsyncMock(return_value={ + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + }) + mock_get_store.return_value = mock_store + + context_dict = await builder.build_context_from_event( + event=event, + binding=binding, + descriptor=descriptor, + resources=resources, + ) # delivery is REQUIRED in Protocol v1 assert 
context_dict.get('delivery') is not None, "delivery is REQUIRED for Protocol v1" diff --git a/tests/unit_tests/agent/test_state_api_auth.py b/tests/unit_tests/agent/test_state_api_auth.py new file mode 100644 index 00000000..8f91f404 --- /dev/null +++ b/tests/unit_tests/agent/test_state_api_auth.py @@ -0,0 +1,538 @@ +"""Tests for State API handler authorization in RuntimeConnectionHandler. + +Tests focus on: +- STATE_GET authorization +- STATE_SET authorization +- STATE_DELETE authorization +- STATE_LIST authorization + +These tests instantiate real RuntimeConnectionHandler action handlers and verify: +- Authorization errors for missing/mismatched caller_plugin_identity +- Authorization errors for disabled state or scope +- Full flow: set -> get -> list -> delete with real SQLite + +Authorization rules: +- caller_plugin_identity is REQUIRED when session has plugin_identity +- caller_plugin_identity must match session's plugin_identity +- enable_state must be True +- scope must be in state_scopes +""" +from __future__ import annotations + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from sqlalchemy.ext.asyncio import create_async_engine + +from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry, get_session_registry +from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore, reset_persistent_state_store +from langbot.pkg.plugin.handler import RuntimeConnectionHandler +from langbot_plugin.runtime.io.connection import Connection +from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction + +# Import shared test fixtures +from .conftest import make_resources + + +class FakeConnection: + """Fake connection for testing.""" + pass + + +class FakeApplication: + """Fake Application for testing.""" + def __init__(self, db_engine=None): + self.logger = MagicMock() + self.logger.debug = MagicMock() + self.logger.warning = MagicMock() + self.logger.error = MagicMock() + 
self.persistence_mgr = MagicMock() + self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + +@pytest.fixture +def session_registry(): + """Create a fresh session registry for each test.""" + return AgentRunSessionRegistry() + + +@pytest.fixture +async def db_engine(): + """Create an in-memory SQLite database for testing.""" + engine = create_async_engine('sqlite+aiosqlite:///:memory:') + yield engine + await engine.dispose() + + +@pytest.fixture +async def persistent_store(db_engine): + """Create a persistent state store with real SQLite.""" + reset_persistent_state_store() + store = PersistentStateStore(db_engine) + + # Create the table + from langbot.pkg.entity.persistence.agent_runner_state import AgentRunnerState + from sqlalchemy import text + async with db_engine.begin() as conn: + await conn.run_sync(AgentRunnerState.__table__.create, checkfirst=True) + + yield store + reset_persistent_state_store() + + +class TestStateAPIHandlerAuthorization: + """Tests for State API handler authorization with real action calls.""" + + @pytest.mark.asyncio + async def test_state_get_missing_run_id_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: missing run_id returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + + # Get the STATE_GET action handler (actions dict is keyed by action value string) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Call without run_id + result = await state_get_handler({'scope': 'conversation', 'key': 'test_key'}) + + assert result.code != 0 + assert 'run_id is required' in result.message + + @pytest.mark.asyncio + async def 
test_state_get_run_not_found_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: run_id not in session registry returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Call with non-existent run_id + result = await state_get_handler({ + 'run_id': 'nonexistent_run', + 'scope': 'conversation', + 'key': 'test_key', + }) + + assert result.code != 0 + assert 'not found' in result.message.lower() + + @pytest.mark.asyncio + async def test_state_get_missing_caller_plugin_identity_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: missing caller_plugin_identity when session has plugin_identity returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + # Register session with plugin_identity + await session_registry.register( + run_id='run_test_missing_identity', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'}, + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Call without caller_plugin_identity + result = await state_get_handler({ + 'run_id': 
'run_test_missing_identity', + 'scope': 'conversation', + 'key': 'test_key', + }) + + assert result.code != 0 + assert 'caller_plugin_identity is required' in result.message + + await session_registry.unregister('run_test_missing_identity') + + @pytest.mark.asyncio + async def test_state_get_caller_identity_mismatch_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: caller_plugin_identity mismatch returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + await session_registry.register( + run_id='run_test_mismatch', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'}, + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Call with wrong caller_plugin_identity + result = await state_get_handler({ + 'run_id': 'run_test_mismatch', + 'scope': 'conversation', + 'key': 'test_key', + 'caller_plugin_identity': 'other/plugin', + }) + + assert result.code != 0 + assert 'mismatch' in result.message.lower() + + await session_registry.unregister('run_test_mismatch') + + @pytest.mark.asyncio + async def test_state_get_enable_state_false_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: enable_state=False returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + await session_registry.register( + run_id='run_test_disabled', + runner_id='plugin:test/runner/default', + 
query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': False, 'state_scopes': []}, + state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + result = await state_get_handler({ + 'run_id': 'run_test_disabled', + 'scope': 'conversation', + 'key': 'test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'disabled' in result.message.lower() + + await session_registry.unregister('run_test_disabled') + + @pytest.mark.asyncio + async def test_state_get_scope_not_enabled_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: scope not in state_scopes returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + await session_registry.register( + run_id='run_test_scope_disabled', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {'conversation': 'conv_key', 'actor': 'actor_key'}, 'binding_identity': 'binding_1'}, + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Request 'actor' scope which is not in state_scopes + result = await state_get_handler({ + 'run_id': 'run_test_scope_disabled', + 'scope': 'actor', + 'key': 'test_key', + 
'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not enabled' in result.message.lower() or 'scope' in result.message.lower() + + await session_registry.unregister('run_test_scope_disabled') + + @pytest.mark.asyncio + async def test_state_get_missing_scope_key_returns_error(self, session_registry, db_engine, persistent_store): + """STATE_GET: missing scope_key in state_context returns error.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + await session_registry.register( + run_id='run_test_no_scope_key', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, # No scope_keys + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + result = await state_get_handler({ + 'run_id': 'run_test_no_scope_key', + 'scope': 'conversation', + 'key': 'test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not available' in result.message.lower() + + await session_registry.unregister('run_test_no_scope_key') + + +class TestStateAPIFullFlowWithRealDB: + """Tests for complete State API flow with real SQLite database.""" + + @pytest.mark.asyncio + async def test_state_set_get_list_delete_flow(self, session_registry, db_engine, persistent_store): + """Test complete state flow: set -> get -> list -> delete with real SQLite.""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + # Register session + await 
session_registry.register( + run_id='run_full_flow', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation', 'runner']}, + state_context={ + 'scope_keys': { + 'conversation': 'conv:test_runner:binding_1:conv_123', + 'runner': 'runner:test_runner:binding_1', + }, + 'binding_identity': 'binding_1', + 'conversation_id': 'conv_123', + }, + ) + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + + # Verify session has correct state_context + session = await session_registry.get('run_full_flow') + assert session is not None + state_ctx = session.get('state_context') + assert state_ctx is not None, f"state_context is None. Session keys: {list(session.keys())}" + assert 'scope_keys' in state_ctx, f"scope_keys not in state_context: {state_ctx}" + assert 'conversation' in state_ctx['scope_keys'], f"conversation not in scope_keys: {state_ctx['scope_keys']}" + + # Get handlers (actions dict is keyed by action value string) + state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value] + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + state_list_handler = handler.actions[PluginToRuntimeAction.STATE_LIST.value] + state_delete_handler = handler.actions[PluginToRuntimeAction.STATE_DELETE.value] + + # 1. STATE_SET + set_result = await state_set_handler({ + 'run_id': 'run_full_flow', + 'scope': 'conversation', + 'key': 'external.test_key', + 'value': {'data': 'test_value'}, + 'caller_plugin_identity': 'test/runner', + }) + + assert set_result.code == 0 + assert set_result.data.get('success') is True + + # 2. 
STATE_GET + get_result = await state_get_handler({ + 'run_id': 'run_full_flow', + 'scope': 'conversation', + 'key': 'external.test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert get_result.code == 0 + assert get_result.data.get('value') == {'data': 'test_value'} + + # 3. STATE_LIST + list_result = await state_list_handler({ + 'run_id': 'run_full_flow', + 'scope': 'conversation', + 'prefix': 'external.', + 'caller_plugin_identity': 'test/runner', + }) + + assert list_result.code == 0 + keys = list_result.data.get('keys', []) + assert 'external.test_key' in keys + + # 4. STATE_DELETE + delete_result = await state_delete_handler({ + 'run_id': 'run_full_flow', + 'scope': 'conversation', + 'key': 'external.test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert delete_result.code == 0 + + # 5. Verify deleted + get_after_delete = await state_get_handler({ + 'run_id': 'run_full_flow', + 'scope': 'conversation', + 'key': 'external.test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert get_after_delete.code == 0 + assert get_after_delete.data.get('value') is None + + await session_registry.unregister('run_full_flow') + + +class TestStateHandlerReadsFromSessionTopLevel: + """Tests verifying handlers read state_policy/state_context from session top-level, not resources.""" + + @pytest.mark.asyncio + async def test_state_handler_reads_state_policy_from_session_top_level(self, session_registry, db_engine, persistent_store): + """Handler reads state_policy from session['state_policy'], not session['resources']['state_policy'].""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + # Register with explicit state_policy at top level + await session_registry.register( + run_id='run_policy_top_level', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': False, 
'state_scopes': []}, # Disabled at top level + state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, + ) + + # Verify resources does NOT contain state_policy + session = await session_registry.get('run_policy_top_level') + assert session is not None + assert 'state_policy' not in session.get('resources', {}), \ + "resources should NOT contain state_policy" + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value] + + # Should fail because enable_state=False in session['state_policy'] + result = await state_get_handler({ + 'run_id': 'run_policy_top_level', + 'scope': 'conversation', + 'key': 'test_key', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'disabled' in result.message.lower() + + await session_registry.unregister('run_policy_top_level') + + @pytest.mark.asyncio + async def test_state_handler_reads_state_context_from_session_top_level(self, session_registry, db_engine, persistent_store): + """Handler reads state_context from session['state_context'], not session['resources']['state_context'].""" + fake_app = FakeApplication(db_engine) + fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + # Register with explicit state_context at top level + await session_registry.register( + run_id='run_context_top_level', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=make_resources(), + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {'conversation': 'conv_key_xyz'}, 'binding_identity': 'binding_xyz'}, + ) + + # Verify resources does NOT contain state_context + session = await session_registry.get('run_context_top_level') + assert session is not 
None + assert 'state_context' not in session.get('resources', {}), \ + "resources should NOT contain state_context" + + async def fake_disconnect(): + return True + + with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry): + handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value] + + # Should use scope_key from session['state_context']['scope_keys']['conversation'] + result = await state_set_handler({ + 'run_id': 'run_context_top_level', + 'scope': 'conversation', + 'key': 'test_key', + 'value': 'test_value', + 'caller_plugin_identity': 'test/runner', + }) + + # Should succeed - scope_key was found in state_context + assert result.code == 0 + + await session_registry.unregister('run_context_top_level') + + +class TestResourcesDoesNotContainStateMetadata: + """Tests verifying resources is clean - no state metadata mixed in.""" + + @pytest.mark.asyncio + async def test_resources_clean_after_register(self, session_registry): + """After register(), resources should not contain state_policy or state_context.""" + resources = make_resources() + + await session_registry.register( + run_id='run_resources_clean', + runner_id='plugin:test/runner/default', + query_id=1, + plugin_identity='test/runner', + resources=resources, + state_policy={'enable_state': True, 'state_scopes': ['conversation']}, + state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'}, + ) + + session = await session_registry.get('run_resources_clean') + assert session is not None + + # Verify resources is clean + session_resources = session.get('resources', {}) + assert 'state_policy' not in session_resources, \ + "session['resources'] should NOT contain state_policy" + assert 'state_context' not in session_resources, \ + "session['resources'] should NOT contain state_context" + + # Verify state metadata is at top level + assert 
'state_policy' in session + assert 'state_context' in session + + await session_registry.unregister('run_resources_clean') diff --git a/tests/unit_tests/agent/test_state_store.py b/tests/unit_tests/agent/test_state_store.py index af85a9a6..bb9a1020 100644 --- a/tests/unit_tests/agent/test_state_store.py +++ b/tests/unit_tests/agent/test_state_store.py @@ -1137,4 +1137,238 @@ class TestStateStorePolicyEnforcement: ) assert result is False - assert any('not enabled' in w for w in logger.warnings) \ No newline at end of file + assert any('not enabled' in w for w in logger.warnings) + + +# ========== Persistent State Store Tests ========== + + +import pytest +import asyncio +import tempfile +import os +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + + +class TestPersistentStateStore: + """Tests for persistent database-backed state store.""" + + @pytest.fixture + async def db_engine(self): + """Create a temporary async SQLite database for testing.""" + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f: + db_path = f.name + + engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False) + + # Create tables + from langbot.pkg.entity.persistence.base import Base + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + # Cleanup + await engine.dispose() + os.unlink(db_path) + + @pytest.fixture + async def persistent_store(self, db_engine): + """Create a persistent state store for testing.""" + from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore + store = PersistentStateStore(db_engine) + yield store + await store.clear_all() + + @pytest.mark.asyncio + async def test_build_snapshot_empty(self, persistent_store): + """Building snapshot from empty store returns empty scopes.""" + descriptor = make_descriptor() + event = FakeEventEnvelope(conversation_id='conv_001') + binding = FakeBinding() + + snapshot = await 
persistent_store.build_snapshot_from_event(event, binding, descriptor) + + assert snapshot['conversation'] == {'external.conversation_id': 'conv_001'} + assert snapshot['actor'] == {} + assert snapshot['subject'] == {} + assert snapshot['runner'] == {} + + @pytest.mark.asyncio + async def test_state_set_and_get(self, persistent_store): + """State set/get round trip.""" + descriptor = make_descriptor() + event = FakeEventEnvelope(conversation_id='conv_001') + binding = FakeBinding() + + # Set state + success, error = await persistent_store.apply_update_from_event( + event, binding, descriptor, 'conversation', 'test_key', {'nested': 'value'}, None + ) + assert success is True + assert error is None + + # Get via snapshot + snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor) + assert snapshot['conversation']['test_key'] == {'nested': 'value'} + + @pytest.mark.asyncio + async def test_binding_isolation(self, persistent_store): + """Different binding_id should have isolated state.""" + descriptor = make_descriptor() + event = FakeEventEnvelope(conversation_id='conv_001') + binding_a = FakeBinding(binding_id='binding_a') + binding_b = FakeBinding(binding_id='binding_b') + + # Set for binding_a + await persistent_store.apply_update_from_event( + event, binding_a, descriptor, 'conversation', 'key', 'value_a', None + ) + + # binding_b should not see binding_a's state + snapshot_b = await persistent_store.build_snapshot_from_event(event, binding_b, descriptor) + assert snapshot_b['conversation'] == {'external.conversation_id': 'conv_001'} + + # binding_a should see its own state + snapshot_a = await persistent_store.build_snapshot_from_event(event, binding_a, descriptor) + assert snapshot_a['conversation']['key'] == 'value_a' + + @pytest.mark.asyncio + async def test_policy_disable_state(self, persistent_store): + """enable_state=False should return empty snapshot and reject updates.""" + descriptor = make_descriptor() + event = 
FakeEventEnvelope(conversation_id='conv_001') + policy = StatePolicy(enable_state=False) + binding = FakeBinding(state_policy=policy) + + # Snapshot should be empty + snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor) + assert snapshot == {'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}} + + # Update should be rejected + success, error = await persistent_store.apply_update_from_event( + event, binding, descriptor, 'conversation', 'key', 'value', None + ) + assert success is False + assert 'disabled' in error.lower() + + @pytest.mark.asyncio + async def test_policy_scope_restriction(self, persistent_store): + """state_scopes should restrict which scopes are accessible.""" + descriptor = make_descriptor() + event = FakeEventEnvelope( + conversation_id='conv_001', + actor=FakeActorContext(actor_id='user_001'), + ) + policy = StatePolicy(state_scopes=['conversation']) # Only conversation + binding = FakeBinding(state_policy=policy) + + # Conversation should work + success_conv, _ = await persistent_store.apply_update_from_event( + event, binding, descriptor, 'conversation', 'key', 'value_conv', None + ) + assert success_conv is True + + # Actor should be rejected + success_actor, error_actor = await persistent_store.apply_update_from_event( + event, binding, descriptor, 'actor', 'key', 'value_actor', None + ) + assert success_actor is False + assert 'not enabled' in error_actor.lower() + + @pytest.mark.asyncio + async def test_value_json_size_limit(self, persistent_store): + """Value exceeding size limit should be rejected.""" + descriptor = make_descriptor() + event = FakeEventEnvelope(conversation_id='conv_001') + binding = FakeBinding() + + # Create a large value (> 256KB) + large_value = 'x' * (300 * 1024) + + success, error = await persistent_store.apply_update_from_event( + event, binding, descriptor, 'conversation', 'key', large_value, None + ) + assert success is False + assert 'exceeds limit' in error.lower() + 
@pytest.mark.asyncio
async def test_value_not_json_serializable(self, persistent_store):
    """An update whose value cannot be JSON-encoded is refused with a JSON-related error."""
    runner_descriptor = make_descriptor()
    envelope = FakeEventEnvelope(conversation_id='conv_001')
    runner_binding = FakeBinding()

    # A set nested inside a dict: the json module has no encoding for set objects.
    unencodable = {'key': {1, 2, 3}}

    outcome = await persistent_store.apply_update_from_event(
        envelope, runner_binding, runner_descriptor, 'conversation', 'key', unencodable, None
    )
    accepted, message = outcome
    assert accepted is False
    assert 'json' in message.lower()


@pytest.mark.asyncio
async def test_state_list(self, persistent_store):
    """Listing a scope returns all stored keys, or only those starting with a given prefix."""
    runner_descriptor = make_descriptor()
    envelope = FakeEventEnvelope(conversation_id='conv_001')
    runner_binding = FakeBinding()

    # Seed three keys in the conversation scope; two of them share the 'external.' prefix.
    seeded_entries = (
        ('external.id', '123'),
        ('external.name', 'test'),
        ('memory.key', 'value'),
    )
    for state_key, state_value in seeded_entries:
        await persistent_store.apply_update_from_event(
            envelope, runner_binding, runner_descriptor, 'conversation', state_key, state_value, None
        )

    # state_list() takes a scope key, so derive it with a throwaway store built without an engine.
    from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
    conversation_scope = PersistentStateStore(None)._make_conversation_scope_key(
        envelope, runner_binding, runner_descriptor
    )

    # Unfiltered listing: all three keys, nothing left over for a further page.
    every_key, more_pending = await persistent_store.state_list(conversation_scope)
    assert len(every_key) == 3
    assert more_pending is False

    # Prefix-filtered listing: only the two 'external.' keys.
    external_only, _ = await persistent_store.state_list(conversation_scope, prefix='external.')
    assert len(external_only) == 2
    assert 'external.id' in external_only
    assert 'external.name' in external_only


@pytest.mark.asyncio
async def test_state_delete(self, persistent_store):
    """Deleting a stored key removes it from the snapshot; deleting it again reports False."""
    runner_descriptor = make_descriptor()
    envelope = FakeEventEnvelope(conversation_id='conv_001')
    runner_binding = FakeBinding()

    # Store a value and confirm it is visible in the snapshot before deleting it.
    await persistent_store.apply_update_from_event(
        envelope, runner_binding, runner_descriptor, 'conversation', 'key', 'value', None
    )
    before = await persistent_store.build_snapshot_from_event(envelope, runner_binding, runner_descriptor)
    assert before['conversation']['key'] == 'value'

    # state_delete() takes a scope key, so derive it with a throwaway store built without an engine.
    from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
    conversation_scope = PersistentStateStore(None)._make_conversation_scope_key(
        envelope, runner_binding, runner_descriptor
    )

    # First delete hits an existing key.
    first_delete = await persistent_store.state_delete(conversation_scope, 'key')
    assert first_delete is True

    # The key must no longer appear in a freshly built snapshot.
    after = await persistent_store.build_snapshot_from_event(envelope, runner_binding, runner_descriptor)
    assert 'key' not in after['conversation']

    # Second delete targets a key that is already gone.
    second_delete = await persistent_store.state_delete(conversation_scope, 'key')
    assert second_delete is False