feat(agent-runner): add persistent state APIs

This commit is contained in:
huanghuoguoguo
2026-05-23 21:45:11 +08:00
parent 4e68a93df7
commit ce007c49c8
12 changed files with 2407 additions and 62 deletions

View File

@@ -13,6 +13,7 @@ from .descriptor import AgentRunnerDescriptor
from .config_migration import ConfigMigration
from .context_packager import AgentContextPackager
from .state_store import get_state_store
from .persistent_state_store import get_persistent_state_store
from . import events as runner_events
from .host_models import AgentEventEnvelope, AgentBinding
from .pipeline_compat_adapter import PipelineCompatAdapter
@@ -259,11 +260,13 @@ class AgentRunContextBuilder:
# Build context access (no history inlined by default for Protocol v1)
# Populate with actual values from stores
context_access = await self._build_context_access(event, descriptor)
context_access = await self._build_context_access(event, descriptor, binding)
# Build state snapshot from event context
state_store = get_state_store()
state: AgentRunState = state_store.build_snapshot_from_event(event, binding, descriptor)
# Build state snapshot from persistent state store (event-first Protocol v1)
persistent_state_store = get_persistent_state_store(
self.ap.persistence_mgr.get_db_engine()
)
state: AgentRunState = await persistent_state_store.build_snapshot_from_event(event, binding, descriptor)
# Build runtime context
runtime: AgentRuntimeContext = {
@@ -420,6 +423,7 @@ class AgentRunContextBuilder:
}
# Build context access (for legacy, minimal API availability)
# Legacy Query-based mode does NOT have persistent state API
context_access = {
'conversation_id': conversation.get('conversation_id') if conversation else None,
'thread_id': None,
@@ -441,7 +445,7 @@ class AgentRunContextBuilder:
'event_page': False,
'artifact_metadata': False,
'artifact_read': False,
'state': True,
'state': False, # Legacy Query mode does not have persistent state API
'storage': True,
},
}
@@ -869,12 +873,14 @@ class AgentRunContextBuilder:
self,
event: AgentEventEnvelope,
descriptor: AgentRunnerDescriptor,
binding: AgentBinding | None = None,
) -> dict[str, typing.Any]:
"""Build ContextAccess with actual values from stores.
Args:
event: Event envelope
descriptor: Runner descriptor
binding: Agent binding (required for state_policy in event-first mode)
Returns:
ContextAccess dict
@@ -895,6 +901,14 @@ class AgentRunContextBuilder:
artifact_metadata_enabled = 'metadata' in artifact_permissions
artifact_read_enabled = 'read' in artifact_permissions
# Determine state API availability based on binding state_policy (event-first mode)
# For legacy Query-based mode, state is NOT available (no persistent state API)
state_enabled = False
if binding is not None:
state_policy = binding.state_policy
if state_policy.enable_state and state_policy.state_scopes:
state_enabled = True
# Get latest cursor and has_history_before if conversation exists
latest_cursor = None
has_history_before = False
@@ -931,7 +945,7 @@ class AgentRunContextBuilder:
'event_page': event_page_enabled,
'artifact_metadata': artifact_metadata_enabled,
'artifact_read': artifact_read_enabled,
'state': True,
'state': state_enabled,
'storage': True,
},
}

View File

@@ -17,6 +17,7 @@ from .context_builder import AgentRunContextBuilder, AgentRunContextPayload
from .resource_builder import AgentResourceBuilder
from .result_normalizer import AgentResultNormalizer
from .state_store import get_state_store, RunnerScopedStateStore
from .persistent_state_store import get_persistent_state_store, PersistentStateStore
from .session_registry import get_session_registry, AgentRunSessionRegistry
from .config_migration import ConfigMigration
from .host_models import AgentEventEnvelope, AgentBinding
@@ -63,6 +64,7 @@ class AgentRunOrchestrator:
# Cached singleton references (set in __init__)
_session_registry: AgentRunSessionRegistry
_state_store: RunnerScopedStateStore
_persistent_state_store: PersistentStateStore | None
def __init__(
self,
@@ -77,6 +79,7 @@ class AgentRunOrchestrator:
# Cache singleton references to avoid per-request getter calls
self._session_registry = get_session_registry()
self._state_store = get_state_store()
self._persistent_state_store = None # Lazy init on first use
async def run(
self,
@@ -122,6 +125,9 @@ class AgentRunOrchestrator:
resources=resources,
)
# Build state context for State API handlers
state_context = self._build_state_context(event, binding, descriptor)
# Register session for proxy action permission validation
run_id = context['run_id']
await self._session_registry.register(
@@ -132,6 +138,11 @@ class AgentRunOrchestrator:
resources=resources,
permissions=descriptor.permissions or {},
conversation_id=event.conversation_id,
state_policy={
'enable_state': binding.state_policy.enable_state,
'state_scopes': list(binding.state_policy.state_scopes),
},
state_context=state_context,
)
# Write incoming event to EventLog
@@ -170,7 +181,7 @@ class AgentRunOrchestrator:
# Handle state.updated first - consume before normalizer
if result_dict.get('type') == 'state.updated':
self._handle_state_updated_event(result_dict, event, binding, descriptor)
await self._handle_state_updated_event(result_dict, event, binding, descriptor)
# Pass to normalizer for logging, but don't yield to pipeline
await self.result_normalizer.normalize(result_dict, descriptor)
continue
@@ -555,7 +566,7 @@ class AgentRunOrchestrator:
f'artifact.created failed to register artifact: {e}',
)
def _handle_state_updated_event(
async def _handle_state_updated_event(
self,
result_dict: dict[str, typing.Any],
event: AgentEventEnvelope,
@@ -564,6 +575,8 @@ class AgentRunOrchestrator:
) -> None:
"""Handle state.updated result in event-first mode.
Persists state to database via PersistentStateStore.
Args:
result_dict: Raw result dict with type='state.updated'
event: Event envelope
@@ -585,8 +598,14 @@ class AgentRunOrchestrator:
)
return
# Apply update to state store using event context
success = self._state_store.apply_update_from_event(
# Lazy init persistent state store
if self._persistent_state_store is None:
self._persistent_state_store = get_persistent_state_store(
self.ap.persistence_mgr.get_db_engine()
)
# Apply update to persistent state store
success, error = await self._persistent_state_store.apply_update_from_event(
event=event,
binding=binding,
descriptor=descriptor,
@@ -600,7 +619,79 @@ class AgentRunOrchestrator:
self.ap.logger.debug(
f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}'
)
# Invalid scope or missing identity is already logged by apply_update_from_event
elif error:
self.ap.logger.warning(
f'Runner {descriptor.id} state.updated rejected: {error}'
)
def _build_state_context(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> dict[str, typing.Any]:
"""Build state context for State API handlers.
Returns context with:
- scope_keys: Dict mapping scope name to scope_key
- binding_identity: Binding identity for state isolation
- Additional context fields for DB insert
"""
# Get binding identity
binding_identity = binding.binding_id
if not binding_identity:
scope = binding.scope
if scope.scope_type and scope.scope_id:
binding_identity = f"{scope.scope_type}:{scope.scope_id}"
else:
binding_identity = "unknown_binding"
# Build scope keys for each scope
scope_keys: dict[str, str] = {}
# Conversation scope
if event.conversation_id:
parts = [descriptor.id, binding_identity, event.conversation_id]
if event.thread_id:
parts.append(event.thread_id)
scope_keys['conversation'] = f'conversation:{":".join(parts)}'
# Actor scope
if event.actor and event.actor.actor_id:
parts = [
descriptor.id,
binding_identity,
event.actor.actor_type or 'user',
event.actor.actor_id,
]
scope_keys['actor'] = f'actor:{":".join(parts)}'
# Subject scope
if event.subject and event.subject.subject_id:
parts = [
descriptor.id,
binding_identity,
event.subject.subject_type or 'unknown',
event.subject.subject_id,
]
scope_keys['subject'] = f'subject:{":".join(parts)}'
# Runner scope (always available)
parts = [descriptor.id, binding_identity]
scope_keys['runner'] = f'runner:{":".join(parts)}'
return {
'scope_keys': scope_keys,
'binding_identity': binding_identity,
'bot_id': event.bot_id,
'workspace_id': event.workspace_id,
'conversation_id': event.conversation_id,
'thread_id': event.thread_id,
'actor_type': event.actor.actor_type if event.actor else None,
'actor_id': event.actor.actor_id if event.actor else None,
'subject_type': event.subject.subject_type if event.subject else None,
'subject_id': event.subject.subject_id if event.subject else None,
}
async def _write_event_log(
self,

View File

@@ -0,0 +1,522 @@
"""Persistent state store for AgentRunner protocol state.
This module provides a database-backed state store for event-first Protocol v1,
while preserving in-memory state store for legacy Query-based flow.
"""
from __future__ import annotations
import typing
import json
import asyncio
import threading
from datetime import datetime
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy import select, delete, update
from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query
from .descriptor import AgentRunnerDescriptor
from .host_models import AgentEventEnvelope, AgentBinding
from ...entity.persistence.agent_runner_state import AgentRunnerState
# Valid state scopes for agent runner state updates.
VALID_STATE_SCOPES = ('conversation', 'actor', 'subject', 'runner')
# Key mapping for backward compatibility
LEGACY_KEY_MAPPING = {
'conversation_id': 'external.conversation_id',
}
# Maximum value_json size (256KB)
MAX_VALUE_JSON_BYTES = 256 * 1024
class PersistentStateStore:
"""Database-backed state store for AgentRunner protocol state.
IMPORTANT: This is HOST-OWNED protocol state, NOT plugin instance state.
This store provides:
1. Persistent storage across runs via database
2. Scope isolation by runner_id + binding_identity + scope
3. Policy enforcement (enable_state, state_scopes)
4. JSON value validation and size limits
Used by:
- Event-first Protocol v1 (async methods)
- State API handlers (get/set/delete/list)
"""
def __init__(self, db_engine: AsyncEngine):
self._db_engine = db_engine
# ========== Scope Key Building (shared with in-memory store) ==========
def _get_binding_identity(self, binding: AgentBinding) -> str:
"""Get stable binding identity for scope key."""
if binding.binding_id:
return binding.binding_id
scope = binding.scope
if scope.scope_type and scope.scope_id:
return f"{scope.scope_type}:{scope.scope_id}"
return "unknown_binding"
def _make_conversation_scope_key(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> str | None:
"""Build conversation scope key from event and binding."""
if not event.conversation_id:
return None
binding_identity = self._get_binding_identity(binding)
parts = [
descriptor.id,
binding_identity,
event.conversation_id,
]
if event.thread_id:
parts.append(event.thread_id)
return f'conversation:{":".join(parts)}'
def _make_actor_scope_key(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> str | None:
"""Build actor scope key from event and binding."""
if not event.actor or not event.actor.actor_id:
return None
binding_identity = self._get_binding_identity(binding)
parts = [
descriptor.id,
binding_identity,
event.actor.actor_type or 'user',
event.actor.actor_id,
]
return f'actor:{":".join(parts)}'
def _make_subject_scope_key(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> str | None:
"""Build subject scope key from event and binding."""
if not event.subject or not event.subject.subject_id:
return None
binding_identity = self._get_binding_identity(binding)
parts = [
descriptor.id,
binding_identity,
event.subject.subject_type or 'unknown',
event.subject.subject_id,
]
return f'subject:{":".join(parts)}'
def _make_runner_scope_key(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> str:
"""Build runner scope key from event and binding."""
binding_identity = self._get_binding_identity(binding)
parts = [
descriptor.id,
binding_identity,
]
return f'runner:{":".join(parts)}'
def _get_scope_key(
self,
scope: str,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> str | None:
"""Get scope key for given scope."""
if scope == 'conversation':
return self._make_conversation_scope_key(event, binding, descriptor)
elif scope == 'actor':
return self._make_actor_scope_key(event, binding, descriptor)
elif scope == 'subject':
return self._make_subject_scope_key(event, binding, descriptor)
elif scope == 'runner':
return self._make_runner_scope_key(event, binding, descriptor)
return None
def _check_scope_enabled(self, scope: str, binding: AgentBinding) -> bool:
"""Check if scope is enabled by binding's state_policy."""
state_policy = binding.state_policy
if not state_policy.enable_state:
return False
return scope in state_policy.state_scopes
def _validate_json_value(
self,
value: typing.Any,
logger: typing.Any = None,
) -> tuple[str | None, str | None]:
"""Validate and serialize value to JSON.
Returns:
Tuple of (json_string, error_message). If error_message is not None,
json_string will be None.
"""
try:
json_str = json.dumps(value, ensure_ascii=False)
except (TypeError, ValueError) as e:
return None, f'Value is not JSON-serializable: {e}'
# Check size limit
json_bytes = len(json_str.encode('utf-8'))
if json_bytes > MAX_VALUE_JSON_BYTES:
return None, f'Value size {json_bytes} bytes exceeds limit {MAX_VALUE_JSON_BYTES} bytes'
return json_str, None
# ========== Async DB Operations ==========
async def build_snapshot_from_event(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
) -> dict[str, dict[str, typing.Any]]:
"""Build state snapshot for all scopes from event and binding.
Reads from database, respects state_policy.
"""
state_policy = binding.state_policy
# If state is disabled, return all empty scopes
if not state_policy.enable_state:
return {
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
}
snapshot: dict[str, dict[str, typing.Any]] = {
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
}
async with self._db_engine.connect() as conn:
for scope in VALID_STATE_SCOPES:
if not self._check_scope_enabled(scope, binding):
continue
scope_key = self._get_scope_key(scope, event, binding, descriptor)
if not scope_key:
continue
# Query all state entries for this scope_key
result = await conn.execute(
select(AgentRunnerState.state_key, AgentRunnerState.value_json)
.where(AgentRunnerState.scope_key == scope_key)
)
rows = result.fetchall()
for row in rows:
key = row.state_key
value_json = row.value_json
if value_json:
try:
snapshot[scope][key] = json.loads(value_json)
except json.JSONDecodeError:
pass # Skip invalid JSON
# Seed external.conversation_id from event.conversation_id if not set
if self._check_scope_enabled('conversation', binding) and event.conversation_id:
if 'external.conversation_id' not in snapshot['conversation']:
snapshot['conversation']['external.conversation_id'] = event.conversation_id
return snapshot
async def apply_update_from_event(
self,
event: AgentEventEnvelope,
binding: AgentBinding,
descriptor: AgentRunnerDescriptor,
scope: str,
key: str,
value: typing.Any,
logger: typing.Any = None,
) -> tuple[bool, str | None]:
"""Apply a state update from event context.
Returns:
Tuple of (success, error_message). If success is False, error_message
contains the reason.
"""
state_policy = binding.state_policy
# Check if state is disabled
if not state_policy.enable_state:
return False, 'State is disabled by binding policy'
# Validate scope
if scope not in VALID_STATE_SCOPES:
return False, f'Invalid scope: {scope}'
# Check if scope is enabled
if not self._check_scope_enabled(scope, binding):
return False, f'Scope "{scope}" not enabled by binding policy'
# Map legacy key names
if key in LEGACY_KEY_MAPPING:
key = LEGACY_KEY_MAPPING[key]
# Get scope key
scope_key = self._get_scope_key(scope, event, binding, descriptor)
if not scope_key:
return False, f'Missing identity for scope "{scope}"'
# Validate and serialize value
value_json, error = self._validate_json_value(value, logger)
if error:
return False, error
# Build context fields
binding_identity = self._get_binding_identity(binding)
async with self._db_engine.begin() as conn:
# Check if entry exists
result = await conn.execute(
select(AgentRunnerState.id)
.where(AgentRunnerState.scope_key == scope_key)
.where(AgentRunnerState.state_key == key)
)
existing = result.first()
now = datetime.utcnow()
if existing:
# Update existing entry
await conn.execute(
update(AgentRunnerState)
.where(AgentRunnerState.id == existing.id)
.values(
value_json=value_json,
updated_at=now,
)
)
else:
# Insert new entry
await conn.execute(
sqlalchemy.insert(AgentRunnerState).values(
runner_id=descriptor.id,
binding_identity=binding_identity,
scope=scope,
scope_key=scope_key,
state_key=key,
value_json=value_json,
bot_id=event.bot_id,
workspace_id=event.workspace_id,
conversation_id=event.conversation_id,
thread_id=event.thread_id,
actor_type=event.actor.actor_type if event.actor else None,
actor_id=event.actor.actor_id if event.actor else None,
subject_type=event.subject.subject_type if event.subject else None,
subject_id=event.subject.subject_id if event.subject else None,
created_at=now,
updated_at=now,
)
)
return True, None
async def state_get(
self,
scope_key: str,
state_key: str,
) -> typing.Any:
"""Get a single state value by scope_key and state_key.
Used by State API handlers.
"""
async with self._db_engine.connect() as conn:
result = await conn.execute(
select(AgentRunnerState.value_json)
.where(AgentRunnerState.scope_key == scope_key)
.where(AgentRunnerState.state_key == state_key)
)
row = result.first()
if not row or not row.value_json:
return None
try:
return json.loads(row.value_json)
except json.JSONDecodeError:
return None
async def state_set(
self,
scope_key: str,
state_key: str,
value: typing.Any,
runner_id: str,
binding_identity: str,
scope: str,
context: dict[str, typing.Any] | None = None,
logger: typing.Any = None,
) -> tuple[bool, str | None]:
"""Set a state value.
Used by State API handlers.
Context contains optional fields like bot_id, conversation_id, etc.
"""
# Validate and serialize value
value_json, error = self._validate_json_value(value, logger)
if error:
return False, error
context = context or {}
async with self._db_engine.begin() as conn:
# Check if entry exists
result = await conn.execute(
select(AgentRunnerState.id)
.where(AgentRunnerState.scope_key == scope_key)
.where(AgentRunnerState.state_key == state_key)
)
existing = result.first()
now = datetime.utcnow()
if existing:
# Update existing entry
await conn.execute(
update(AgentRunnerState)
.where(AgentRunnerState.id == existing.id)
.values(
value_json=value_json,
updated_at=now,
)
)
else:
# Insert new entry
await conn.execute(
sqlalchemy.insert(AgentRunnerState).values(
runner_id=runner_id,
binding_identity=binding_identity,
scope=scope,
scope_key=scope_key,
state_key=state_key,
value_json=value_json,
bot_id=context.get('bot_id'),
workspace_id=context.get('workspace_id'),
conversation_id=context.get('conversation_id'),
thread_id=context.get('thread_id'),
actor_type=context.get('actor_type'),
actor_id=context.get('actor_id'),
subject_type=context.get('subject_type'),
subject_id=context.get('subject_id'),
created_at=now,
updated_at=now,
)
)
return True, None
async def state_delete(
self,
scope_key: str,
state_key: str,
) -> bool:
"""Delete a state value.
Returns True if deleted, False if not found.
"""
async with self._db_engine.begin() as conn:
result = await conn.execute(
delete(AgentRunnerState)
.where(AgentRunnerState.scope_key == scope_key)
.where(AgentRunnerState.state_key == state_key)
.returning(AgentRunnerState.id)
)
deleted = result.first()
return deleted is not None
async def state_list(
self,
scope_key: str,
prefix: str | None = None,
limit: int = 100,
) -> tuple[list[str], bool]:
"""List state keys in a scope.
Returns tuple of (keys, has_more).
"""
# Enforce limit cap
limit = min(limit, 100)
async with self._db_engine.connect() as conn:
query = (
select(AgentRunnerState.state_key)
.where(AgentRunnerState.scope_key == scope_key)
.order_by(AgentRunnerState.state_key)
.limit(limit + 1) # Fetch one extra to check has_more
)
if prefix:
query = query.where(
AgentRunnerState.state_key.like(f'{prefix}%')
)
result = await conn.execute(query)
rows = result.fetchall()
keys = [row.state_key for row in rows[:limit]]
has_more = len(rows) > limit
return keys, has_more
async def clear_all(self) -> None:
"""Clear all state entries (for testing)."""
async with self._db_engine.begin() as conn:
await conn.execute(delete(AgentRunnerState))
# Global singleton persistent state store
_persistent_state_store: PersistentStateStore | None = None
_persistent_state_store_lock = threading.Lock()
def get_persistent_state_store(db_engine: AsyncEngine | None = None) -> PersistentStateStore:
"""Get the global persistent state store singleton.
Args:
db_engine: Database engine (required on first call)
Returns:
PersistentStateStore singleton
"""
global _persistent_state_store
with _persistent_state_store_lock:
if _persistent_state_store is None:
if db_engine is None:
raise RuntimeError("db_engine required for first call to get_persistent_state_store")
_persistent_state_store = PersistentStateStore(db_engine)
return _persistent_state_store
def reset_persistent_state_store() -> None:
"""Reset the global persistent state store (for testing)."""
global _persistent_state_store
with _persistent_state_store_lock:
_persistent_state_store = None

View File

@@ -28,6 +28,8 @@ class AgentRunSession(typing.TypedDict):
conversation_id: Conversation ID for history/event access
resources: Authorized resources for this run (from AgentResources)
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
state_policy: State policy from binding (enable_state, state_scopes)
state_context: Context for state API (scope_keys, binding_identity, etc.)
status: Session status tracking
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
"""
@@ -38,6 +40,8 @@ class AgentRunSession(typing.TypedDict):
conversation_id: str | None
resources: AgentResources
permissions: dict[str, list[str]]
state_policy: dict[str, typing.Any] # {enable_state: bool, state_scopes: list}
state_context: dict[str, typing.Any] # {scope_keys: dict, binding_identity: str, ...}
status: AgentRunSessionStatus
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
@@ -70,6 +74,8 @@ class AgentRunSessionRegistry:
resources: AgentResources,
conversation_id: str | None = None,
permissions: dict[str, list[str]] | None = None,
state_policy: dict[str, typing.Any] | None = None,
state_context: dict[str, typing.Any] | None = None,
) -> None:
"""Register a new agent run session.
@@ -81,12 +87,21 @@ class AgentRunSessionRegistry:
resources: Authorized resources for this run
conversation_id: Conversation ID for history/event access
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
state_policy: State policy from binding (enable_state, state_scopes)
state_context: Context for state API (scope_keys, binding_identity, etc.)
"""
now = int(time.time())
# Normalize permissions to empty dict if None
permissions = permissions or {}
# Normalize state_policy to defaults if None
if state_policy is None:
state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
# Normalize state_context to empty dict if None
state_context = state_context or {}
# Pre-compute authorized resource IDs for O(1) lookup
authorized_ids: dict[str, set[str]] = {
'model': {m.get('model_id') for m in resources.get('models', [])},
@@ -95,14 +110,18 @@ class AgentRunSessionRegistry:
'file': {f.get('file_id') for f in resources.get('files', [])},
}
# NOTE: state_policy and state_context are stored at session top-level,
# NOT in resources. Resources should only contain resource authorization info.
session: AgentRunSession = {
'run_id': run_id,
'runner_id': runner_id,
'query_id': query_id,
'plugin_identity': plugin_identity,
'conversation_id': conversation_id,
'resources': resources,
'resources': resources, # Original AgentResources, no state metadata mixed in
'permissions': permissions,
'state_policy': state_policy,
'state_context': state_context,
'status': {
'started_at': now,
'last_activity_at': now,

View File

@@ -0,0 +1,89 @@
"""Agent runner state persistence entity for host-owned state."""
from __future__ import annotations
import sqlalchemy
import datetime
from .base import Base
class AgentRunnerState(Base):
"""AgentRunnerState stores host-owned state for AgentRunner protocol.
State is:
- Host-owned: Managed by LangBot, not by plugin instances
- Scope-isolated: Separated by runner_id + binding_identity + scope
- Policy-enforced: Controlled by StatePolicy (enable_state, state_scopes)
Scope key design:
- conversation: runner_id + binding_id + conversation_id [+ thread_id]
- actor: runner_id + binding_id + actor_type + actor_id
- subject: runner_id + binding_id + subject_type + subject_id
- runner: runner_id + binding_id
This table persists state across runs, replacing the in-memory
RunnerScopedStateStore._store dict.
"""
__tablename__ = 'agent_runner_state'
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
"""Auto-increment ID for sequencing."""
# Identity
runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
"""Runner descriptor ID (plugin:author/name/runner)."""
binding_identity = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
"""Binding identity for isolation (binding_id or scope_type:scope_id)."""
scope = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
"""State scope: 'conversation', 'actor', 'subject', or 'runner'."""
scope_key = sqlalchemy.Column(sqlalchemy.String(512), nullable=False, index=True)
"""Full scope key for unique lookup (includes all identity parts)."""
state_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
"""State key within scope (should use namespace prefix like external.*)."""
value_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
"""State value as JSON string (size-limited by host)."""
# Context fields for querying/filtering
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
"""Bot UUID if applicable."""
workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
"""Workspace ID for multi-tenant."""
conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
"""Conversation ID for conversation scope."""
thread_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
"""Thread ID for thread-scoped conversation state."""
actor_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True)
"""Actor type for actor scope."""
actor_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
"""Actor ID for actor scope."""
subject_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=True)
"""Subject type for subject scope."""
subject_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
"""Subject ID for subject scope."""
# Lifecycle
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
"""When this state entry was created."""
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
"""When this state entry was last updated."""
# Unique constraint: scope_key + state_key
__table_args__ = (
sqlalchemy.UniqueConstraint('scope_key', 'state_key', name='uq_agent_runner_state_scope_key_state_key'),
sqlalchemy.Index('ix_agent_runner_state_runner_binding', 'runner_id', 'binding_identity'),
sqlalchemy.Index('ix_agent_runner_state_scope_key_lookup', 'scope_key'),
)

View File

@@ -16,6 +16,7 @@ from langbot.pkg.entity.persistence.base import Base
# Import all ORM models so they are registered with Base.metadata
# This is required for autogenerate to detect model changes
from langbot.pkg.entity.persistence import (
agent_runner_state,
apikey,
artifact,
bot,

View File

@@ -0,0 +1,68 @@
# Alembic script.py.mako — template for auto-generated revisions
"""add agent_runner_state table for host-owned persistent state
Revision ID: 6dfd3dd7f0c7
Revises: a1b2c3d4e5f6
Create Date: 2026-05-23 19:49:08.529110
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '6dfd3dd7f0c7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('agent_runner_state',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('runner_id', sa.String(length=255), nullable=False),
sa.Column('binding_identity', sa.String(length=255), nullable=False),
sa.Column('scope', sa.String(length=50), nullable=False),
sa.Column('scope_key', sa.String(length=512), nullable=False),
sa.Column('state_key', sa.String(length=255), nullable=False),
sa.Column('value_json', sa.Text(), nullable=True),
sa.Column('bot_id', sa.String(length=255), nullable=True),
sa.Column('workspace_id', sa.String(length=255), nullable=True),
sa.Column('conversation_id', sa.String(length=255), nullable=True),
sa.Column('thread_id', sa.String(length=255), nullable=True),
sa.Column('actor_type', sa.String(length=50), nullable=True),
sa.Column('actor_id', sa.String(length=255), nullable=True),
sa.Column('subject_type', sa.String(length=50), nullable=True),
sa.Column('subject_id', sa.String(length=255), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('scope_key', 'state_key', name='uq_agent_runner_state_scope_key_state_key')
)
with op.batch_alter_table('agent_runner_state', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_agent_runner_state_actor_id'), ['actor_id'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_binding_identity'), ['binding_identity'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_bot_id'), ['bot_id'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_conversation_id'), ['conversation_id'], unique=False)
batch_op.create_index('ix_agent_runner_state_runner_binding', ['runner_id', 'binding_identity'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_runner_id'), ['runner_id'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_scope'), ['scope'], unique=False)
batch_op.create_index(batch_op.f('ix_agent_runner_state_scope_key'), ['scope_key'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('agent_runner_state', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_agent_runner_state_scope_key'))
batch_op.drop_index(batch_op.f('ix_agent_runner_state_scope'))
batch_op.drop_index(batch_op.f('ix_agent_runner_state_runner_id'))
batch_op.drop_index('ix_agent_runner_state_runner_binding')
batch_op.drop_index(batch_op.f('ix_agent_runner_state_conversation_id'))
batch_op.drop_index(batch_op.f('ix_agent_runner_state_bot_id'))
batch_op.drop_index(batch_op.f('ix_agent_runner_state_binding_identity'))
batch_op.drop_index(batch_op.f('ix_agent_runner_state_actor_id'))
op.drop_table('agent_runner_state')
# ### end Alembic commands ###

View File

@@ -1348,10 +1348,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1420,10 +1424,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1483,10 +1491,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1538,10 +1550,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1610,10 +1626,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1694,10 +1714,14 @@ class RuntimeConnectionHandler(handler.Handler):
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity
if caller_plugin_identity:
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
@@ -1746,6 +1770,346 @@ class RuntimeConnectionHandler(handler.Handler):
self.ap.logger.error(f'ARTIFACT_READ error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Artifact read error: {e}')
# ================= State APIs (run-scoped, policy-enforced) =================
@self.action(PluginToRuntimeAction.STATE_GET)
async def state_get(data: dict[str, Any]) -> handler.ActionResponse:
"""Get a state value from host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
# Validate run session
session_registry = get_session_registry()
session = await session_registry.get(run_id)
if not session:
return handler.ActionResponse.error(
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
# Get state policy from session (stored in state_policy field, not in resources)
state_policy = session.get('state_policy', {})
if not state_policy:
# Default state policy
state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
# Check if state is enabled
if not state_policy.get('enable_state', True):
return handler.ActionResponse.error(
message='State access is disabled by binding policy'
)
# Check if scope is enabled
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
if scope not in state_scopes:
return handler.ActionResponse.error(
message=f'Scope "{scope}" is not enabled by binding policy'
)
# Build scope key using state_context from session (stored in state_context field, not in resources)
state_context = session.get('state_context', {})
scope_key = state_context.get('scope_keys', {}).get(scope)
if not scope_key:
return handler.ActionResponse.error(
message=f'Scope key not available for scope "{scope}"'
)
# Get state from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
try:
value = await store.state_get(scope_key, key)
return handler.ActionResponse.success(data={'value': value})
except Exception as e:
self.ap.logger.error(f'STATE_GET error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State get error: {e}')
@self.action(PluginToRuntimeAction.STATE_SET)
async def state_set(data: dict[str, Any]) -> handler.ActionResponse:
"""Set a state value in host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
Value must be JSON-serializable and size-limited.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
value = data.get('value')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
# Validate run session
session_registry = get_session_registry()
session = await session_registry.get(run_id)
if not session:
return handler.ActionResponse.error(
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
# Get state policy from session (stored in state_policy field, not in resources)
state_policy = session.get('state_policy', {})
if not state_policy:
state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
# Check if state is enabled
if not state_policy.get('enable_state', True):
return handler.ActionResponse.error(
message='State access is disabled by binding policy'
)
# Check if scope is enabled
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
if scope not in state_scopes:
return handler.ActionResponse.error(
message=f'Scope "{scope}" is not enabled by binding policy'
)
# Build scope key using state_context from session (stored in state_context field, not in resources)
state_context = session.get('state_context', {})
scope_key = state_context.get('scope_keys', {}).get(scope)
if not scope_key:
return handler.ActionResponse.error(
message=f'Scope key not available for scope "{scope}"'
)
# Get additional context for DB insert
runner_id = session.get('runner_id', '')
binding_identity = state_context.get('binding_identity', 'unknown')
# Set state in persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
try:
success, error = await store.state_set(
scope_key=scope_key,
state_key=key,
value=value,
runner_id=runner_id,
binding_identity=binding_identity,
scope=scope,
context=state_context,
logger=self.ap.logger,
)
if not success:
return handler.ActionResponse.error(message=error or 'Failed to set state')
return handler.ActionResponse.success(data={'success': True})
except Exception as e:
self.ap.logger.error(f'STATE_SET error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State set error: {e}')
@self.action(PluginToRuntimeAction.STATE_DELETE)
async def state_delete(data: dict[str, Any]) -> handler.ActionResponse:
"""Delete a state value from host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
# Validate run session
session_registry = get_session_registry()
session = await session_registry.get(run_id)
if not session:
return handler.ActionResponse.error(
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
# Get state policy from session (stored in state_policy field, not in resources)
state_policy = session.get('state_policy', {})
if not state_policy:
state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
# Check if state is enabled
if not state_policy.get('enable_state', True):
return handler.ActionResponse.error(
message='State access is disabled by binding policy'
)
# Check if scope is enabled
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
if scope not in state_scopes:
return handler.ActionResponse.error(
message=f'Scope "{scope}" is not enabled by binding policy'
)
# Build scope key using state_context from session (stored in state_context field, not in resources)
state_context = session.get('state_context', {})
scope_key = state_context.get('scope_keys', {}).get(scope)
if not scope_key:
return handler.ActionResponse.error(
message=f'Scope key not available for scope "{scope}"'
)
# Delete state from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
try:
deleted = await store.state_delete(scope_key, key)
return handler.ActionResponse.success(data={'success': deleted})
except Exception as e:
self.ap.logger.error(f'STATE_DELETE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State delete error: {e}')
@self.action(PluginToRuntimeAction.STATE_LIST)
async def state_list(data: dict[str, Any]) -> handler.ActionResponse:
"""List state keys in a scope.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
prefix = data.get('prefix')
limit = data.get('limit', 100)
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
# Validate limit
if not isinstance(limit, int) or limit <= 0:
limit = 100
limit = min(limit, 100) # Cap at 100
# Validate run session
session_registry = get_session_registry()
session = await session_registry.get(run_id)
if not session:
return handler.ActionResponse.error(
message=f'Run session {run_id} not found or expired'
)
# Validate caller plugin identity (strict: required when session has plugin_identity)
session_plugin_identity = session.get('plugin_identity')
if session_plugin_identity:
if not caller_plugin_identity:
return handler.ActionResponse.error(
message=f'caller_plugin_identity is required for run_id {run_id}'
)
if caller_plugin_identity != session_plugin_identity:
return handler.ActionResponse.error(
message=f'Plugin identity mismatch for run_id {run_id}'
)
# Get state policy from session (stored in state_policy field, not in resources)
state_policy = session.get('state_policy', {})
if not state_policy:
state_policy = {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
# Check if state is enabled
if not state_policy.get('enable_state', True):
return handler.ActionResponse.error(
message='State access is disabled by binding policy'
)
# Check if scope is enabled
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
if scope not in state_scopes:
return handler.ActionResponse.error(
message=f'Scope "{scope}" is not enabled by binding policy'
)
# Build scope key using state_context from session (stored in state_context field, not in resources)
state_context = session.get('state_context', {})
scope_key = state_context.get('scope_keys', {}).get(scope)
if not scope_key:
return handler.ActionResponse.error(
message=f'Scope key not available for scope "{scope}"'
)
# List state keys from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
try:
keys, has_more = await store.state_list(scope_key, prefix, limit)
return handler.ActionResponse.success(data={
'keys': keys,
'has_more': has_more,
})
except Exception as e:
self.ap.logger.error(f'STATE_LIST error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State list error: {e}')
@self.action(CommonAction.PING)
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
"""Ping"""

View File

@@ -0,0 +1,361 @@
"""Tests for ContextAccess.state determination in AgentRunContextBuilder.
Tests focus on:
- Event-first mode: state=True when enable_state=True and state_scopes non-empty
- Event-first mode: state=False when enable_state=False
- Legacy Query mode: state=False (no persistent state API)
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
class MockApplication:
"""Mock Application for testing."""
def __init__(self):
self.logger = MagicMock()
self.persistence_mgr = MagicMock()
self.persistence_mgr.get_db_engine = MagicMock()
class TestContextAccessStateDetermination:
"""Tests for ContextAccess.state field determination - real calls to _build_context_access."""
@pytest.fixture
def mock_app(self):
"""Create mock application."""
return MockApplication()
@pytest.fixture
def mock_event(self):
"""Create mock event envelope."""
return AgentEventEnvelope(
event_id='evt_001',
event_type='message.received',
event_time=1234567890,
source='test',
bot_id='bot_001',
workspace_id='ws_001',
conversation_id='conv_001',
thread_id=None,
actor=ActorContext(actor_type='user', actor_id='user_001'),
subject=None,
input=AgentInput(text='hello', contents=[], attachments=[]),
delivery=DeliveryContext(surface='test', supports_streaming=True),
)
@pytest.fixture
def mock_descriptor(self):
"""Create mock runner descriptor."""
descriptor = MagicMock()
descriptor.id = 'plugin:test/runner/default'
descriptor.protocol_version = '1.0'
descriptor.permissions = {}
return descriptor
@pytest.mark.asyncio
async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
"""ContextAccess.state=True when enable_state=True and state_scopes non-empty."""
# Create binding with state enabled and non-empty scopes
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(
enable_state=True,
state_scopes=['conversation', 'actor'],
),
)
builder = AgentRunContextBuilder(mock_app)
# Real call to _build_context_access
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Verify state=True based on binding.state_policy
assert context_access['available_apis']['state'] is True
@pytest.mark.asyncio
async def test_enable_state_false_sets_state_false(self, mock_app, mock_event, mock_descriptor):
"""ContextAccess.state=False when enable_state=False."""
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(
enable_state=False,
state_scopes=[],
),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Verify state=False
assert context_access['available_apis']['state'] is False
@pytest.mark.asyncio
async def test_enable_state_true_empty_scopes_sets_state_false(self, mock_app, mock_event, mock_descriptor):
"""ContextAccess.state=False when enable_state=True but state_scopes empty."""
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(
enable_state=True,
state_scopes=[], # Empty scopes - state not available
),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Verify state=False (empty scopes means state not available)
assert context_access['available_apis']['state'] is False
@pytest.mark.asyncio
async def test_no_binding_sets_state_false(self, mock_app, mock_event, mock_descriptor):
"""ContextAccess.state=False when binding is None (legacy mode)."""
builder = AgentRunContextBuilder(mock_app)
# Real call without binding
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding=None)
# Verify state=False (no binding = no state policy = state disabled)
assert context_access['available_apis']['state'] is False
@pytest.mark.asyncio
async def test_runner_scope_available_without_conversation(self, mock_app, mock_descriptor):
"""State API with runner scope is available even without conversation_id."""
mock_event = AgentEventEnvelope(
event_id='evt_002',
event_type='message.received',
event_time=1234567890,
source='test',
bot_id='bot_001',
workspace_id='ws_001',
conversation_id=None, # No conversation
thread_id=None,
actor=ActorContext(actor_type='user', actor_id='user_001'),
subject=None,
input=AgentInput(text='hello', contents=[], attachments=[]),
delivery=DeliveryContext(surface='test', supports_streaming=True),
)
binding = AgentBinding(
binding_id='binding_002',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='workspace', scope_id='ws_001'),
state_policy=StatePolicy(
enable_state=True,
state_scopes=['runner'], # Runner scope doesn't need conversation_id
),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# State should be True because runner scope is enabled
assert context_access['available_apis']['state'] is True
@pytest.mark.asyncio
async def test_multiple_scopes_all_available(self, mock_app, mock_event, mock_descriptor):
"""State API with multiple scopes enabled."""
binding = AgentBinding(
binding_id='binding_003',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(
enable_state=True,
state_scopes=['conversation', 'actor', 'subject', 'runner'],
),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# State should be True with all scopes enabled
assert context_access['available_apis']['state'] is True
class TestStatePolicyFromBinding:
"""Tests for state_policy extraction from binding."""
def test_state_policy_structure(self):
"""State policy has correct structure."""
policy = StatePolicy(
enable_state=True,
state_scopes=['conversation', 'actor', 'subject', 'runner'],
)
assert policy.enable_state is True
assert len(policy.state_scopes) == 4
assert 'conversation' in policy.state_scopes
def test_state_policy_disabled(self):
"""State policy can be disabled."""
policy = StatePolicy(
enable_state=False,
state_scopes=[],
)
assert policy.enable_state is False
assert len(policy.state_scopes) == 0
class TestBindingWithStatePolicy:
"""Tests for binding with state_policy."""
def test_binding_contains_state_policy(self):
"""Binding contains state_policy field."""
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(
enable_state=True,
state_scopes=['conversation'],
),
)
assert binding.state_policy is not None
assert binding.state_policy.enable_state is True
class TestContextAccessOtherAPIs:
"""Tests for other available_apis fields based on permissions."""
@pytest.fixture
def mock_app(self):
"""Create mock application."""
return MockApplication()
@pytest.mark.asyncio
async def test_history_apis_based_on_permissions(self, mock_app):
"""History APIs availability based on runner permissions."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'history': ['page', 'search'],
}
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# History APIs enabled based on permissions
assert context_access['available_apis']['history_page'] is True
assert context_access['available_apis']['history_search'] is True
@pytest.mark.asyncio
async def test_event_apis_based_on_permissions(self, mock_app):
"""Event APIs availability based on runner permissions."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'events': ['get', 'page'],
}
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Event APIs enabled based on permissions
assert context_access['available_apis']['event_get'] is True
assert context_access['available_apis']['event_page'] is True
@pytest.mark.asyncio
async def test_artifact_apis_based_on_permissions(self, mock_app):
"""Artifact APIs availability based on runner permissions."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'artifacts': ['metadata', 'read'],
}
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Artifact APIs enabled based on permissions
assert context_access['available_apis']['artifact_metadata'] is True
assert context_access['available_apis']['artifact_read'] is True
@pytest.mark.asyncio
async def test_no_permissions_all_apis_disabled(self, mock_app):
"""All pull APIs disabled when permissions are empty."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {} # No permissions
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
)
builder = AgentRunContextBuilder(mock_app)
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# All pull APIs should be disabled
assert context_access['available_apis']['history_page'] is False
assert context_access['available_apis']['history_search'] is False
assert context_access['available_apis']['event_get'] is False
assert context_access['available_apis']['event_page'] is False
assert context_access['available_apis']['artifact_metadata'] is False
assert context_access['available_apis']['artifact_read'] is False
assert context_access['available_apis']['state'] is False

View File

@@ -113,13 +113,24 @@ class TestContextValidation:
resources = self._make_resources()
descriptor = self._make_descriptor()
# Build context
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Mock persistent state store to return empty state snapshot
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
mock_store = AsyncMock()
mock_store.build_snapshot_from_event = AsyncMock(return_value={
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
})
mock_get_store.return_value = mock_store
# Build context
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Validate it can be parsed by SDK AgentRunContext
# This will raise ValidationError if invalid
@@ -162,12 +173,23 @@ class TestContextValidation:
resources = self._make_resources()
descriptor = self._make_descriptor()
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Mock persistent state store to return empty state snapshot
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
mock_store = AsyncMock()
mock_store.build_snapshot_from_event = AsyncMock(return_value={
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
})
mock_get_store.return_value = mock_store
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Protocol v1 does NOT have these as core fields
assert 'messages' not in context_dict, "messages should not be top-level in Protocol v1"
@@ -192,12 +214,23 @@ class TestContextValidation:
resources = self._make_resources()
descriptor = self._make_descriptor()
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Mock persistent state store to return empty state snapshot
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
mock_store = AsyncMock()
mock_store.build_snapshot_from_event = AsyncMock(return_value={
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
})
mock_get_store.return_value = mock_store
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# event is REQUIRED in Protocol v1
assert context_dict.get('event') is not None, "event is REQUIRED for Protocol v1"
@@ -217,12 +250,23 @@ class TestContextValidation:
resources = self._make_resources()
descriptor = self._make_descriptor()
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# Mock persistent state store to return empty state snapshot
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
mock_store = AsyncMock()
mock_store.build_snapshot_from_event = AsyncMock(return_value={
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
})
mock_get_store.return_value = mock_store
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
# delivery is REQUIRED in Protocol v1
assert context_dict.get('delivery') is not None, "delivery is REQUIRED for Protocol v1"

View File

@@ -0,0 +1,538 @@
"""Tests for State API handler authorization in RuntimeConnectionHandler.
Tests focus on:
- STATE_GET authorization
- STATE_SET authorization
- STATE_DELETE authorization
- STATE_LIST authorization
These tests instantiate real RuntimeConnectionHandler action handlers and verify:
- Authorization errors for missing/mismatched caller_plugin_identity
- Authorization errors for disabled state or scope
- Full flow: set -> get -> list -> delete with real SQLite
Authorization rules:
- caller_plugin_identity is REQUIRED when session has plugin_identity
- caller_plugin_identity must match session's plugin_identity
- enable_state must be True
- scope must be in state_scopes
"""
from __future__ import annotations
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry, get_session_registry
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore, reset_persistent_state_store
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
from langbot_plugin.runtime.io.connection import Connection
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
# Import shared test fixtures
from .conftest import make_resources
class FakeConnection:
"""Fake connection for testing."""
pass
class FakeApplication:
"""Fake Application for testing."""
def __init__(self, db_engine=None):
self.logger = MagicMock()
self.logger.debug = MagicMock()
self.logger.warning = MagicMock()
self.logger.error = MagicMock()
self.persistence_mgr = MagicMock()
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
@pytest.fixture
def session_registry():
"""Create a fresh session registry for each test."""
return AgentRunSessionRegistry()
@pytest.fixture
async def db_engine():
"""Create an in-memory SQLite database for testing."""
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
yield engine
await engine.dispose()
@pytest.fixture
async def persistent_store(db_engine):
"""Create a persistent state store with real SQLite."""
reset_persistent_state_store()
store = PersistentStateStore(db_engine)
# Create the table
from langbot.pkg.entity.persistence.agent_runner_state import AgentRunnerState
from sqlalchemy import text
async with db_engine.begin() as conn:
await conn.run_sync(AgentRunnerState.__table__.create, checkfirst=True)
yield store
reset_persistent_state_store()
class TestStateAPIHandlerAuthorization:
"""Tests for State API handler authorization with real action calls."""
@pytest.mark.asyncio
async def test_state_get_missing_run_id_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: missing run_id returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
# Get the STATE_GET action handler (actions dict is keyed by action value string)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Call without run_id
result = await state_get_handler({'scope': 'conversation', 'key': 'test_key'})
assert result.code != 0
assert 'run_id is required' in result.message
@pytest.mark.asyncio
async def test_state_get_run_not_found_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: run_id not in session registry returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Call with non-existent run_id
result = await state_get_handler({
'run_id': 'nonexistent_run',
'scope': 'conversation',
'key': 'test_key',
})
assert result.code != 0
assert 'not found' in result.message.lower()
@pytest.mark.asyncio
async def test_state_get_missing_caller_plugin_identity_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: missing caller_plugin_identity when session has plugin_identity returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
# Register session with plugin_identity
await session_registry.register(
run_id='run_test_missing_identity',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Call without caller_plugin_identity
result = await state_get_handler({
'run_id': 'run_test_missing_identity',
'scope': 'conversation',
'key': 'test_key',
})
assert result.code != 0
assert 'caller_plugin_identity is required' in result.message
await session_registry.unregister('run_test_missing_identity')
@pytest.mark.asyncio
async def test_state_get_caller_identity_mismatch_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: caller_plugin_identity mismatch returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
await session_registry.register(
run_id='run_test_mismatch',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Call with wrong caller_plugin_identity
result = await state_get_handler({
'run_id': 'run_test_mismatch',
'scope': 'conversation',
'key': 'test_key',
'caller_plugin_identity': 'other/plugin',
})
assert result.code != 0
assert 'mismatch' in result.message.lower()
await session_registry.unregister('run_test_mismatch')
@pytest.mark.asyncio
async def test_state_get_enable_state_false_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: enable_state=False returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
await session_registry.register(
run_id='run_test_disabled',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': False, 'state_scopes': []},
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
result = await state_get_handler({
'run_id': 'run_test_disabled',
'scope': 'conversation',
'key': 'test_key',
'caller_plugin_identity': 'test/runner',
})
assert result.code != 0
assert 'disabled' in result.message.lower()
await session_registry.unregister('run_test_disabled')
@pytest.mark.asyncio
async def test_state_get_scope_not_enabled_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: scope not in state_scopes returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
await session_registry.register(
run_id='run_test_scope_disabled',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {'conversation': 'conv_key', 'actor': 'actor_key'}, 'binding_identity': 'binding_1'},
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Request 'actor' scope which is not in state_scopes
result = await state_get_handler({
'run_id': 'run_test_scope_disabled',
'scope': 'actor',
'key': 'test_key',
'caller_plugin_identity': 'test/runner',
})
assert result.code != 0
assert 'not enabled' in result.message.lower() or 'scope' in result.message.lower()
await session_registry.unregister('run_test_scope_disabled')
@pytest.mark.asyncio
async def test_state_get_missing_scope_key_returns_error(self, session_registry, db_engine, persistent_store):
"""STATE_GET: missing scope_key in state_context returns error."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
await session_registry.register(
run_id='run_test_no_scope_key',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, # No scope_keys
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
result = await state_get_handler({
'run_id': 'run_test_no_scope_key',
'scope': 'conversation',
'key': 'test_key',
'caller_plugin_identity': 'test/runner',
})
assert result.code != 0
assert 'not available' in result.message.lower()
await session_registry.unregister('run_test_no_scope_key')
class TestStateAPIFullFlowWithRealDB:
"""Tests for complete State API flow with real SQLite database."""
@pytest.mark.asyncio
async def test_state_set_get_list_delete_flow(self, session_registry, db_engine, persistent_store):
"""Test complete state flow: set -> get -> list -> delete with real SQLite."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
# Register session
await session_registry.register(
run_id='run_full_flow',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation', 'runner']},
state_context={
'scope_keys': {
'conversation': 'conv:test_runner:binding_1:conv_123',
'runner': 'runner:test_runner:binding_1',
},
'binding_identity': 'binding_1',
'conversation_id': 'conv_123',
},
)
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
# Verify session has correct state_context
session = await session_registry.get('run_full_flow')
assert session is not None
state_ctx = session.get('state_context')
assert state_ctx is not None, f"state_context is None. Session keys: {list(session.keys())}"
assert 'scope_keys' in state_ctx, f"scope_keys not in state_context: {state_ctx}"
assert 'conversation' in state_ctx['scope_keys'], f"conversation not in scope_keys: {state_ctx['scope_keys']}"
# Get handlers (actions dict is keyed by action value string)
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
state_list_handler = handler.actions[PluginToRuntimeAction.STATE_LIST.value]
state_delete_handler = handler.actions[PluginToRuntimeAction.STATE_DELETE.value]
# 1. STATE_SET
set_result = await state_set_handler({
'run_id': 'run_full_flow',
'scope': 'conversation',
'key': 'external.test_key',
'value': {'data': 'test_value'},
'caller_plugin_identity': 'test/runner',
})
assert set_result.code == 0
assert set_result.data.get('success') is True
# 2. STATE_GET
get_result = await state_get_handler({
'run_id': 'run_full_flow',
'scope': 'conversation',
'key': 'external.test_key',
'caller_plugin_identity': 'test/runner',
})
assert get_result.code == 0
assert get_result.data.get('value') == {'data': 'test_value'}
# 3. STATE_LIST
list_result = await state_list_handler({
'run_id': 'run_full_flow',
'scope': 'conversation',
'prefix': 'external.',
'caller_plugin_identity': 'test/runner',
})
assert list_result.code == 0
keys = list_result.data.get('keys', [])
assert 'external.test_key' in keys
# 4. STATE_DELETE
delete_result = await state_delete_handler({
'run_id': 'run_full_flow',
'scope': 'conversation',
'key': 'external.test_key',
'caller_plugin_identity': 'test/runner',
})
assert delete_result.code == 0
# 5. Verify deleted
get_after_delete = await state_get_handler({
'run_id': 'run_full_flow',
'scope': 'conversation',
'key': 'external.test_key',
'caller_plugin_identity': 'test/runner',
})
assert get_after_delete.code == 0
assert get_after_delete.data.get('value') is None
await session_registry.unregister('run_full_flow')
class TestStateHandlerReadsFromSessionTopLevel:
"""Tests verifying handlers read state_policy/state_context from session top-level, not resources."""
@pytest.mark.asyncio
async def test_state_handler_reads_state_policy_from_session_top_level(self, session_registry, db_engine, persistent_store):
"""Handler reads state_policy from session['state_policy'], not session['resources']['state_policy']."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
# Register with explicit state_policy at top level
await session_registry.register(
run_id='run_policy_top_level',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': False, 'state_scopes': []}, # Disabled at top level
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
)
# Verify resources does NOT contain state_policy
session = await session_registry.get('run_policy_top_level')
assert session is not None
assert 'state_policy' not in session.get('resources', {}), \
"resources should NOT contain state_policy"
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
# Should fail because enable_state=False in session['state_policy']
result = await state_get_handler({
'run_id': 'run_policy_top_level',
'scope': 'conversation',
'key': 'test_key',
'caller_plugin_identity': 'test/runner',
})
assert result.code != 0
assert 'disabled' in result.message.lower()
await session_registry.unregister('run_policy_top_level')
@pytest.mark.asyncio
async def test_state_handler_reads_state_context_from_session_top_level(self, session_registry, db_engine, persistent_store):
"""Handler reads state_context from session['state_context'], not session['resources']['state_context']."""
fake_app = FakeApplication(db_engine)
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
# Register with explicit state_context at top level
await session_registry.register(
run_id='run_context_top_level',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=make_resources(),
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {'conversation': 'conv_key_xyz'}, 'binding_identity': 'binding_xyz'},
)
# Verify resources does NOT contain state_context
session = await session_registry.get('run_context_top_level')
assert session is not None
assert 'state_context' not in session.get('resources', {}), \
"resources should NOT contain state_context"
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
# Should use scope_key from session['state_context']['scope_keys']['conversation']
result = await state_set_handler({
'run_id': 'run_context_top_level',
'scope': 'conversation',
'key': 'test_key',
'value': 'test_value',
'caller_plugin_identity': 'test/runner',
})
# Should succeed - scope_key was found in state_context
assert result.code == 0
await session_registry.unregister('run_context_top_level')
class TestResourcesDoesNotContainStateMetadata:
"""Tests verifying resources is clean - no state metadata mixed in."""
@pytest.mark.asyncio
async def test_resources_clean_after_register(self, session_registry):
"""After register(), resources should not contain state_policy or state_context."""
resources = make_resources()
await session_registry.register(
run_id='run_resources_clean',
runner_id='plugin:test/runner/default',
query_id=1,
plugin_identity='test/runner',
resources=resources,
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
)
session = await session_registry.get('run_resources_clean')
assert session is not None
# Verify resources is clean
session_resources = session.get('resources', {})
assert 'state_policy' not in session_resources, \
"session['resources'] should NOT contain state_policy"
assert 'state_context' not in session_resources, \
"session['resources'] should NOT contain state_context"
# Verify state metadata is at top level
assert 'state_policy' in session
assert 'state_context' in session
await session_registry.unregister('run_resources_clean')

View File

@@ -1137,4 +1137,238 @@ class TestStateStorePolicyEnforcement:
)
assert result is False
assert any('not enabled' in w for w in logger.warnings)
assert any('not enabled' in w for w in logger.warnings)
# ========== Persistent State Store Tests ==========
import pytest
import asyncio
import tempfile
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
class TestPersistentStateStore:
"""Tests for persistent database-backed state store."""
@pytest.fixture
async def db_engine(self):
"""Create a temporary async SQLite database for testing."""
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
db_path = f.name
engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)
# Create tables
from langbot.pkg.entity.persistence.base import Base
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
# Cleanup
await engine.dispose()
os.unlink(db_path)
@pytest.fixture
async def persistent_store(self, db_engine):
"""Create a persistent state store for testing."""
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
store = PersistentStateStore(db_engine)
yield store
await store.clear_all()
@pytest.mark.asyncio
async def test_build_snapshot_empty(self, persistent_store):
"""Building snapshot from empty store returns empty scopes."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
assert snapshot['conversation'] == {'external.conversation_id': 'conv_001'}
assert snapshot['actor'] == {}
assert snapshot['subject'] == {}
assert snapshot['runner'] == {}
@pytest.mark.asyncio
async def test_state_set_and_get(self, persistent_store):
"""State set/get round trip."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
# Set state
success, error = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'test_key', {'nested': 'value'}, None
)
assert success is True
assert error is None
# Get via snapshot
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
@pytest.mark.asyncio
async def test_binding_isolation(self, persistent_store):
"""Different binding_id should have isolated state."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding_a = FakeBinding(binding_id='binding_a')
binding_b = FakeBinding(binding_id='binding_b')
# Set for binding_a
await persistent_store.apply_update_from_event(
event, binding_a, descriptor, 'conversation', 'key', 'value_a', None
)
# binding_b should not see binding_a's state
snapshot_b = await persistent_store.build_snapshot_from_event(event, binding_b, descriptor)
assert snapshot_b['conversation'] == {'external.conversation_id': 'conv_001'}
# binding_a should see its own state
snapshot_a = await persistent_store.build_snapshot_from_event(event, binding_a, descriptor)
assert snapshot_a['conversation']['key'] == 'value_a'
@pytest.mark.asyncio
async def test_policy_disable_state(self, persistent_store):
"""enable_state=False should return empty snapshot and reject updates."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
policy = StatePolicy(enable_state=False)
binding = FakeBinding(state_policy=policy)
# Snapshot should be empty
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
assert snapshot == {'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}}
# Update should be rejected
success, error = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'key', 'value', None
)
assert success is False
assert 'disabled' in error.lower()
@pytest.mark.asyncio
async def test_policy_scope_restriction(self, persistent_store):
"""state_scopes should restrict which scopes are accessible."""
descriptor = make_descriptor()
event = FakeEventEnvelope(
conversation_id='conv_001',
actor=FakeActorContext(actor_id='user_001'),
)
policy = StatePolicy(state_scopes=['conversation']) # Only conversation
binding = FakeBinding(state_policy=policy)
# Conversation should work
success_conv, _ = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'key', 'value_conv', None
)
assert success_conv is True
# Actor should be rejected
success_actor, error_actor = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'actor', 'key', 'value_actor', None
)
assert success_actor is False
assert 'not enabled' in error_actor.lower()
@pytest.mark.asyncio
async def test_value_json_size_limit(self, persistent_store):
"""Value exceeding size limit should be rejected."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
# Create a large value (> 256KB)
large_value = 'x' * (300 * 1024)
success, error = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'key', large_value, None
)
assert success is False
assert 'exceeds limit' in error.lower()
@pytest.mark.asyncio
async def test_value_not_json_serializable(self, persistent_store):
"""Non-JSON-serializable value should be rejected."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
# Create a non-serializable value (set is not JSON-serializable)
non_serializable = {'key': {1, 2, 3}}
success, error = await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'key', non_serializable, None
)
assert success is False
assert 'json' in error.lower()
@pytest.mark.asyncio
async def test_state_list(self, persistent_store):
"""State list should return keys with optional prefix filter."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
# Set multiple keys
await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'external.id', '123', None
)
await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'external.name', 'test', None
)
await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'memory.key', 'value', None
)
# Build scope key for list
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
temp_store = PersistentStateStore(None)
scope_key = temp_store._make_conversation_scope_key(event, binding, descriptor)
# List all keys
keys, has_more = await persistent_store.state_list(scope_key)
assert len(keys) == 3
assert has_more is False
# List with prefix
keys_ext, _ = await persistent_store.state_list(scope_key, prefix='external.')
assert len(keys_ext) == 2
assert 'external.id' in keys_ext
assert 'external.name' in keys_ext
@pytest.mark.asyncio
async def test_state_delete(self, persistent_store):
"""State delete should remove key."""
descriptor = make_descriptor()
event = FakeEventEnvelope(conversation_id='conv_001')
binding = FakeBinding()
# Set and verify
await persistent_store.apply_update_from_event(
event, binding, descriptor, 'conversation', 'key', 'value', None
)
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
assert snapshot['conversation']['key'] == 'value'
# Build scope key for delete
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
temp_store = PersistentStateStore(None)
scope_key = temp_store._make_conversation_scope_key(event, binding, descriptor)
# Delete
deleted = await persistent_store.state_delete(scope_key, 'key')
assert deleted is True
# Verify deleted
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
assert 'key' not in snapshot['conversation']
# Delete non-existent should return False
deleted_again = await persistent_store.state_delete(scope_key, 'key')
assert deleted_again is False