mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 01:06:03 +00:00
Migrates legacy runner blocks into plugin runner configs, preserves run-scoped history boundaries, enforces operation/file authorization, and sanitizes inline attachment persistence. Also fixes plugin runner form dirty handling and adds regression coverage.
436 lines
14 KiB
Python
436 lines
14 KiB
Python
"""Persistent state store for AgentRunner protocol state.
|
|
|
|
This module provides a database-backed state store for event-first Protocol v1.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import typing
|
|
import json
|
|
import threading
|
|
from datetime import datetime
|
|
|
|
import sqlalchemy
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
from sqlalchemy import select, delete, update
|
|
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from .descriptor import AgentRunnerDescriptor
|
|
from .host_models import AgentEventEnvelope, AgentBinding
|
|
from .state_scope import (
|
|
VALID_STATE_SCOPES,
|
|
build_state_scope_key,
|
|
get_binding_identity,
|
|
normalize_state_key,
|
|
)
|
|
from ...entity.persistence.agent_runner_state import AgentRunnerState
|
|
|
|
|
|
# Upper bound, in bytes, on the UTF-8 encoded JSON text of a single state
# value (256 KiB). Larger values are rejected before they reach the database.
MAX_VALUE_JSON_BYTES = 256 * 1024
|
|
|
|
|
|
class PersistentStateStore:
    """Database-backed state store for AgentRunner protocol state.

    IMPORTANT: This is HOST-OWNED protocol state, NOT plugin instance state.

    This store provides:
    1. Persistent storage across runs via database
    2. Scope isolation by runner_id + binding_identity + scope
    3. Policy enforcement (enable_state, state_scopes)
    4. JSON value validation and size limits

    Used by:
    - Event-first Protocol v1 (async methods)
    - State API handlers (get/set/delete/list)
    """

    def __init__(self, db_engine: AsyncEngine):
        """Bind the store to the async engine used for every query."""
        self._db_engine = db_engine

    def _get_scope_key(
        self,
        scope: str,
        event: AgentEventEnvelope,
        binding: AgentBinding,
        descriptor: AgentRunnerDescriptor,
    ) -> str | None:
        """Get scope key for given scope (delegates to build_state_scope_key)."""
        return build_state_scope_key(scope, event, binding, descriptor)

    def _check_scope_enabled(self, scope: str, binding: AgentBinding) -> bool:
        """Check if scope is enabled by binding's state_policy.

        Returns False when state is disabled altogether, otherwise whether
        ``scope`` is listed in ``state_policy.state_scopes``.
        """
        state_policy = binding.state_policy
        if not state_policy.enable_state:
            return False
        return scope in state_policy.state_scopes

    def _validate_json_value(
        self,
        value: typing.Any,
        logger: typing.Any = None,
    ) -> tuple[str | None, str | None]:
        """Validate and serialize value to JSON.

        Args:
            value: Object to serialize.
            logger: Accepted for interface compatibility; not used here.

        Returns:
            Tuple of (json_string, error_message). If error_message is not None,
            json_string will be None.
        """
        try:
            json_str = json.dumps(value, ensure_ascii=False)
            # Encode inside the try block: with ensure_ascii=False a lone
            # surrogate passes json.dumps() but raises UnicodeEncodeError
            # (a ValueError subclass) when encoded to UTF-8.
            json_bytes = len(json_str.encode('utf-8'))
        except (TypeError, ValueError, RecursionError) as e:
            # RecursionError covers values nested too deeply to serialize.
            return None, f'Value is not JSON-serializable: {e}'

        # Check size limit
        if json_bytes > MAX_VALUE_JSON_BYTES:
            return None, f'Value size {json_bytes} bytes exceeds limit {MAX_VALUE_JSON_BYTES} bytes'

        return json_str, None

    async def _upsert_state_row(
        self,
        conn: typing.Any,
        values: dict[str, typing.Any],
    ) -> None:
        """Insert or update a state row by the logical scope/key identity.

        Args:
            conn: Connection on which the statement is executed.
            values: Full column mapping for the row. On conflict only
                ``value_json`` and ``updated_at`` are overwritten.
        """
        update_values = {
            'value_json': values['value_json'],
            'updated_at': values['updated_at'],
        }
        constraint_columns = ['scope_key', 'state_key']
        dialect_name = self._db_engine.dialect.name

        # SQLite and PostgreSQL both support INSERT ... ON CONFLICT DO UPDATE.
        if dialect_name == 'sqlite':
            stmt = sqlite_insert(AgentRunnerState).values(**values)
            await conn.execute(
                stmt.on_conflict_do_update(
                    index_elements=constraint_columns,
                    set_=update_values,
                )
            )
            return

        if dialect_name == 'postgresql':
            stmt = postgresql_insert(AgentRunnerState).values(**values)
            await conn.execute(
                stmt.on_conflict_do_update(
                    index_elements=constraint_columns,
                    set_=update_values,
                )
            )
            return

        # Other dialects: plain INSERT, falling back to UPDATE on a conflict.
        try:
            await conn.execute(sqlalchemy.insert(AgentRunnerState).values(**values))
        except IntegrityError:
            await conn.execute(
                update(AgentRunnerState)
                .where(AgentRunnerState.scope_key == values['scope_key'])
                .where(AgentRunnerState.state_key == values['state_key'])
                .values(**update_values)
            )

    # ========== Async DB Operations ==========

    async def build_snapshot_from_event(
        self,
        event: AgentEventEnvelope,
        binding: AgentBinding,
        descriptor: AgentRunnerDescriptor,
    ) -> dict[str, dict[str, typing.Any]]:
        """Build state snapshot for all scopes from event and binding.

        Reads from database, respects state_policy.

        Returns:
            Mapping of scope name to ``{state_key: decoded_value}``. The
            'conversation', 'actor', 'subject' and 'runner' scopes are always
            present; they are all empty when state is disabled.
        """
        snapshot: dict[str, dict[str, typing.Any]] = {
            'conversation': {},
            'actor': {},
            'subject': {},
            'runner': {},
        }

        # If state is disabled, return all empty scopes
        if not binding.state_policy.enable_state:
            return snapshot

        async with self._db_engine.connect() as conn:
            for scope in VALID_STATE_SCOPES:
                if not self._check_scope_enabled(scope, binding):
                    continue

                scope_key = self._get_scope_key(scope, event, binding, descriptor)
                if not scope_key:
                    continue

                # Query all state entries for this scope_key
                result = await conn.execute(
                    select(AgentRunnerState.state_key, AgentRunnerState.value_json)
                    .where(AgentRunnerState.scope_key == scope_key)
                )

                # VALID_STATE_SCOPES is defined elsewhere; setdefault avoids a
                # KeyError should it contain a scope missing from the literal above.
                scope_values = snapshot.setdefault(scope, {})
                for row in result.fetchall():
                    if not row.value_json:
                        continue
                    try:
                        scope_values[row.state_key] = json.loads(row.value_json)
                    except json.JSONDecodeError:
                        # Best effort: a row holding invalid JSON is left out
                        # of the snapshot instead of failing the whole read.
                        continue

        # Seed external.conversation_id from event.conversation_id if not set
        if self._check_scope_enabled('conversation', binding) and event.conversation_id:
            snapshot['conversation'].setdefault('external.conversation_id', event.conversation_id)

        return snapshot

    async def apply_update_from_event(
        self,
        event: AgentEventEnvelope,
        binding: AgentBinding,
        descriptor: AgentRunnerDescriptor,
        scope: str,
        key: str,
        value: typing.Any,
        logger: typing.Any = None,
    ) -> tuple[bool, str | None]:
        """Apply a state update from event context.

        The update is rejected when state or the scope is disabled by the
        binding policy, when the scope is unknown, when no scope key can be
        built, or when the value fails JSON validation.

        Returns:
            Tuple of (success, error_message). If success is False, error_message
            contains the reason.
        """
        state_policy = binding.state_policy

        # Check if state is disabled
        if not state_policy.enable_state:
            return False, 'State is disabled by binding policy'

        # Validate scope
        if scope not in VALID_STATE_SCOPES:
            return False, f'Invalid scope: {scope}'

        # Check if scope is enabled
        if not self._check_scope_enabled(scope, binding):
            return False, f'Scope "{scope}" not enabled by binding policy'

        # Map accepted key aliases
        key = normalize_state_key(key)

        # Get scope key
        scope_key = self._get_scope_key(scope, event, binding, descriptor)
        if not scope_key:
            return False, f'Missing identity for scope "{scope}"'

        # Validate and serialize value
        value_json, error = self._validate_json_value(value, logger)
        if error:
            return False, error

        # Build context fields
        binding_identity = get_binding_identity(binding)

        now = datetime.utcnow()
        async with self._db_engine.begin() as conn:
            await self._upsert_state_row(
                conn,
                {
                    'runner_id': descriptor.id,
                    'binding_identity': binding_identity,
                    'scope': scope,
                    'scope_key': scope_key,
                    'state_key': key,
                    'value_json': value_json,
                    'bot_id': event.bot_id,
                    'workspace_id': event.workspace_id,
                    'conversation_id': event.conversation_id,
                    'thread_id': event.thread_id,
                    'actor_type': event.actor.actor_type if event.actor else None,
                    'actor_id': event.actor.actor_id if event.actor else None,
                    'subject_type': event.subject.subject_type if event.subject else None,
                    'subject_id': event.subject.subject_id if event.subject else None,
                    'created_at': now,
                    'updated_at': now,
                },
            )

        return True, None

    async def state_get(
        self,
        scope_key: str,
        state_key: str,
    ) -> typing.Any:
        """Get a single state value by scope_key and state_key.

        Used by State API handlers.

        Returns:
            The decoded JSON value, or None when the row is missing, its
            value_json is empty, or the stored text is not valid JSON.
        """
        state_key = normalize_state_key(state_key)

        async with self._db_engine.connect() as conn:
            result = await conn.execute(
                select(AgentRunnerState.value_json)
                .where(AgentRunnerState.scope_key == scope_key)
                .where(AgentRunnerState.state_key == state_key)
            )
            row = result.first()

        if not row or not row.value_json:
            return None

        try:
            return json.loads(row.value_json)
        except json.JSONDecodeError:
            return None

    async def state_set(
        self,
        scope_key: str,
        state_key: str,
        value: typing.Any,
        runner_id: str,
        binding_identity: str,
        scope: str,
        context: dict[str, typing.Any] | None = None,
        logger: typing.Any = None,
    ) -> tuple[bool, str | None]:
        """Set a state value.

        Used by State API handlers.
        Context contains optional fields like bot_id, conversation_id, etc.

        Returns:
            Tuple of (success, error_message); error_message is set only when
            the value fails JSON validation.
        """
        state_key = normalize_state_key(state_key)

        # Validate and serialize value
        value_json, error = self._validate_json_value(value, logger)
        if error:
            return False, error

        context = context or {}

        now = datetime.utcnow()
        async with self._db_engine.begin() as conn:
            await self._upsert_state_row(
                conn,
                {
                    'runner_id': runner_id,
                    'binding_identity': binding_identity,
                    'scope': scope,
                    'scope_key': scope_key,
                    'state_key': state_key,
                    'value_json': value_json,
                    'bot_id': context.get('bot_id'),
                    'workspace_id': context.get('workspace_id'),
                    'conversation_id': context.get('conversation_id'),
                    'thread_id': context.get('thread_id'),
                    'actor_type': context.get('actor_type'),
                    'actor_id': context.get('actor_id'),
                    'subject_type': context.get('subject_type'),
                    'subject_id': context.get('subject_id'),
                    'created_at': now,
                    'updated_at': now,
                },
            )

        return True, None

    async def state_delete(
        self,
        scope_key: str,
        state_key: str,
    ) -> bool:
        """Delete a state value.

        Returns True if deleted, False if not found.
        """
        state_key = normalize_state_key(state_key)

        async with self._db_engine.begin() as conn:
            result = await conn.execute(
                delete(AgentRunnerState)
                .where(AgentRunnerState.scope_key == scope_key)
                .where(AgentRunnerState.state_key == state_key)
            )
            return (result.rowcount or 0) > 0

    async def state_list(
        self,
        scope_key: str,
        prefix: str | None = None,
        limit: int = 100,
    ) -> tuple[list[str], bool]:
        """List state keys in a scope.

        Args:
            scope_key: Scope whose keys are listed.
            prefix: Optional literal key prefix; ``%`` and ``_`` in it are
                matched literally, not as LIKE wildcards.
            limit: Maximum number of keys to return. Values above 100 are
                capped at 100 and negative values are treated as 0.

        Returns tuple of (keys, has_more).
        """
        # Enforce limit cap; clamp at zero so a negative value can neither
        # reach LIMIT (unbounded on SQLite) nor skew the has_more check.
        limit = max(0, min(limit, 100))

        query = select(AgentRunnerState.state_key).where(AgentRunnerState.scope_key == scope_key)

        if prefix:
            prefix = normalize_state_key(prefix)
            # autoescape makes LIKE metacharacters in the prefix literal.
            query = query.where(AgentRunnerState.state_key.startswith(prefix, autoescape=True))

        # Fetch one extra to check has_more
        query = query.order_by(AgentRunnerState.state_key).limit(limit + 1)

        async with self._db_engine.connect() as conn:
            result = await conn.execute(query)
            rows = result.fetchall()

        keys = [row.state_key for row in rows[:limit]]
        has_more = len(rows) > limit

        return keys, has_more

    async def clear_all(self) -> None:
        """Clear all state entries (for testing)."""
        async with self._db_engine.begin() as conn:
            await conn.execute(delete(AgentRunnerState))
|
|
|
|
|
|
# Process-wide PersistentStateStore singleton; stays None until
# get_persistent_state_store() creates it.
_persistent_state_store: PersistentStateStore | None = None
# Held while the singleton above is created or reset.
_persistent_state_store_lock = threading.Lock()
|
|
|
|
|
|
def get_persistent_state_store(db_engine: AsyncEngine | None = None) -> PersistentStateStore:
    """Return the module-wide PersistentStateStore, creating it on first use.

    Args:
        db_engine: Database engine used to construct the store. Required
            only while no store exists yet; otherwise it is not consulted.

    Returns:
        The shared PersistentStateStore instance.

    Raises:
        RuntimeError: If no store exists yet and ``db_engine`` is None.
    """
    global _persistent_state_store
    with _persistent_state_store_lock:
        store = _persistent_state_store
        if store is not None:
            return store
        if db_engine is None:
            raise RuntimeError("db_engine required for first call to get_persistent_state_store")
        store = PersistentStateStore(db_engine)
        _persistent_state_store = store
        return store
|
|
|
|
|
|
def reset_persistent_state_store() -> None:
    """Drop the module-wide store so the next lookup builds a new one (for testing).

    The assignment happens while holding the module's singleton lock.
    """
    global _persistent_state_store
    with _persistent_state_store_lock:
        _persistent_state_store = None
|