"""Persistent state store for AgentRunner protocol state. This module provides a database-backed state store for event-first Protocol v1. """ from __future__ import annotations import typing import json import threading from datetime import datetime import sqlalchemy from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy import select, delete, update from .descriptor import AgentRunnerDescriptor from .host_models import AgentEventEnvelope, AgentBinding from .state_scope import ( VALID_STATE_SCOPES, build_state_scope_key, get_binding_identity, normalize_state_key, ) from ...entity.persistence.agent_runner_state import AgentRunnerState # Maximum value_json size (256KB) MAX_VALUE_JSON_BYTES = 256 * 1024 class PersistentStateStore: """Database-backed state store for AgentRunner protocol state. IMPORTANT: This is HOST-OWNED protocol state, NOT plugin instance state. This store provides: 1. Persistent storage across runs via database 2. Scope isolation by runner_id + binding_identity + scope 3. Policy enforcement (enable_state, state_scopes) 4. JSON value validation and size limits Used by: - Event-first Protocol v1 (async methods) - State API handlers (get/set/delete/list) """ def __init__(self, db_engine: AsyncEngine): self._db_engine = db_engine def _get_scope_key( self, scope: str, event: AgentEventEnvelope, binding: AgentBinding, descriptor: AgentRunnerDescriptor, ) -> str | None: """Get scope key for given scope.""" return build_state_scope_key(scope, event, binding, descriptor) def _check_scope_enabled(self, scope: str, binding: AgentBinding) -> bool: """Check if scope is enabled by binding's state_policy.""" state_policy = binding.state_policy if not state_policy.enable_state: return False return scope in state_policy.state_scopes def _validate_json_value( self, value: typing.Any, logger: typing.Any = None, ) -> tuple[str | None, str | None]: """Validate and serialize value to JSON. Returns: Tuple of (json_string, error_message). If error_message is not None, json_string will be None. """ try: json_str = json.dumps(value, ensure_ascii=False) except (TypeError, ValueError) as e: return None, f'Value is not JSON-serializable: {e}' # Check size limit json_bytes = len(json_str.encode('utf-8')) if json_bytes > MAX_VALUE_JSON_BYTES: return None, f'Value size {json_bytes} bytes exceeds limit {MAX_VALUE_JSON_BYTES} bytes' return json_str, None # ========== Async DB Operations ========== async def build_snapshot_from_event( self, event: AgentEventEnvelope, binding: AgentBinding, descriptor: AgentRunnerDescriptor, ) -> dict[str, dict[str, typing.Any]]: """Build state snapshot for all scopes from event and binding. Reads from database, respects state_policy. """ state_policy = binding.state_policy # If state is disabled, return all empty scopes if not state_policy.enable_state: return { 'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}, } snapshot: dict[str, dict[str, typing.Any]] = { 'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}, } async with self._db_engine.connect() as conn: for scope in VALID_STATE_SCOPES: if not self._check_scope_enabled(scope, binding): continue scope_key = self._get_scope_key(scope, event, binding, descriptor) if not scope_key: continue # Query all state entries for this scope_key result = await conn.execute( select(AgentRunnerState.state_key, AgentRunnerState.value_json) .where(AgentRunnerState.scope_key == scope_key) ) rows = result.fetchall() for row in rows: key = row.state_key value_json = row.value_json if value_json: try: snapshot[scope][key] = json.loads(value_json) except json.JSONDecodeError: pass # Skip invalid JSON # Seed external.conversation_id from event.conversation_id if not set if self._check_scope_enabled('conversation', binding) and event.conversation_id: if 'external.conversation_id' not in snapshot['conversation']: snapshot['conversation']['external.conversation_id'] = event.conversation_id return snapshot async def apply_update_from_event( self, event: AgentEventEnvelope, binding: AgentBinding, descriptor: AgentRunnerDescriptor, scope: str, key: str, value: typing.Any, logger: typing.Any = None, ) -> tuple[bool, str | None]: """Apply a state update from event context. Returns: Tuple of (success, error_message). If success is False, error_message contains the reason. """ state_policy = binding.state_policy # Check if state is disabled if not state_policy.enable_state: return False, 'State is disabled by binding policy' # Validate scope if scope not in VALID_STATE_SCOPES: return False, f'Invalid scope: {scope}' # Check if scope is enabled if not self._check_scope_enabled(scope, binding): return False, f'Scope "{scope}" not enabled by binding policy' # Map accepted key aliases key = normalize_state_key(key) # Get scope key scope_key = self._get_scope_key(scope, event, binding, descriptor) if not scope_key: return False, f'Missing identity for scope "{scope}"' # Validate and serialize value value_json, error = self._validate_json_value(value, logger) if error: return False, error # Build context fields binding_identity = get_binding_identity(binding) async with self._db_engine.begin() as conn: # Check if entry exists result = await conn.execute( select(AgentRunnerState.id) .where(AgentRunnerState.scope_key == scope_key) .where(AgentRunnerState.state_key == key) ) existing = result.first() now = datetime.utcnow() if existing: # Update existing entry await conn.execute( update(AgentRunnerState) .where(AgentRunnerState.id == existing.id) .values( value_json=value_json, updated_at=now, ) ) else: # Insert new entry await conn.execute( sqlalchemy.insert(AgentRunnerState).values( runner_id=descriptor.id, binding_identity=binding_identity, scope=scope, scope_key=scope_key, state_key=key, value_json=value_json, bot_id=event.bot_id, workspace_id=event.workspace_id, conversation_id=event.conversation_id, thread_id=event.thread_id, actor_type=event.actor.actor_type if event.actor else None, actor_id=event.actor.actor_id if event.actor else None, subject_type=event.subject.subject_type if event.subject else None, subject_id=event.subject.subject_id if event.subject else None, created_at=now, updated_at=now, ) ) return True, None async def state_get( self, scope_key: str, state_key: str, ) -> typing.Any: """Get a single state value by scope_key and state_key. Used by State API handlers. """ state_key = normalize_state_key(state_key) async with self._db_engine.connect() as conn: result = await conn.execute( select(AgentRunnerState.value_json) .where(AgentRunnerState.scope_key == scope_key) .where(AgentRunnerState.state_key == state_key) ) row = result.first() if not row or not row.value_json: return None try: return json.loads(row.value_json) except json.JSONDecodeError: return None async def state_set( self, scope_key: str, state_key: str, value: typing.Any, runner_id: str, binding_identity: str, scope: str, context: dict[str, typing.Any] | None = None, logger: typing.Any = None, ) -> tuple[bool, str | None]: """Set a state value. Used by State API handlers. Context contains optional fields like bot_id, conversation_id, etc. """ state_key = normalize_state_key(state_key) # Validate and serialize value value_json, error = self._validate_json_value(value, logger) if error: return False, error context = context or {} async with self._db_engine.begin() as conn: # Check if entry exists result = await conn.execute( select(AgentRunnerState.id) .where(AgentRunnerState.scope_key == scope_key) .where(AgentRunnerState.state_key == state_key) ) existing = result.first() now = datetime.utcnow() if existing: # Update existing entry await conn.execute( update(AgentRunnerState) .where(AgentRunnerState.id == existing.id) .values( value_json=value_json, updated_at=now, ) ) else: # Insert new entry await conn.execute( sqlalchemy.insert(AgentRunnerState).values( runner_id=runner_id, binding_identity=binding_identity, scope=scope, scope_key=scope_key, state_key=state_key, value_json=value_json, bot_id=context.get('bot_id'), workspace_id=context.get('workspace_id'), conversation_id=context.get('conversation_id'), thread_id=context.get('thread_id'), actor_type=context.get('actor_type'), actor_id=context.get('actor_id'), subject_type=context.get('subject_type'), subject_id=context.get('subject_id'), created_at=now, updated_at=now, ) ) return True, None async def state_delete( self, scope_key: str, state_key: str, ) -> bool: """Delete a state value. Returns True if deleted, False if not found. """ state_key = normalize_state_key(state_key) async with self._db_engine.begin() as conn: result = await conn.execute( delete(AgentRunnerState) .where(AgentRunnerState.scope_key == scope_key) .where(AgentRunnerState.state_key == state_key) .returning(AgentRunnerState.id) ) deleted = result.first() return deleted is not None async def state_list( self, scope_key: str, prefix: str | None = None, limit: int = 100, ) -> tuple[list[str], bool]: """List state keys in a scope. Returns tuple of (keys, has_more). """ # Enforce limit cap limit = min(limit, 100) async with self._db_engine.connect() as conn: query = ( select(AgentRunnerState.state_key) .where(AgentRunnerState.scope_key == scope_key) .order_by(AgentRunnerState.state_key) .limit(limit + 1) # Fetch one extra to check has_more ) if prefix: prefix = normalize_state_key(prefix) query = query.where( AgentRunnerState.state_key.like(f'{prefix}%') ) result = await conn.execute(query) rows = result.fetchall() keys = [row.state_key for row in rows[:limit]] has_more = len(rows) > limit return keys, has_more async def clear_all(self) -> None: """Clear all state entries (for testing).""" async with self._db_engine.begin() as conn: await conn.execute(delete(AgentRunnerState)) # Global singleton persistent state store _persistent_state_store: PersistentStateStore | None = None _persistent_state_store_lock = threading.Lock() def get_persistent_state_store(db_engine: AsyncEngine | None = None) -> PersistentStateStore: """Get the global persistent state store singleton. Args: db_engine: Database engine (required on first call) Returns: PersistentStateStore singleton """ global _persistent_state_store with _persistent_state_store_lock: if _persistent_state_store is None: if db_engine is None: raise RuntimeError("db_engine required for first call to get_persistent_state_store") _persistent_state_store = PersistentStateStore(db_engine) return _persistent_state_store def reset_persistent_state_store() -> None: """Reset the global persistent state store (for testing).""" global _persistent_state_store with _persistent_state_store_lock: _persistent_state_store = None