mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 06:46:02 +00:00
feat(agent-runner): add persistent state APIs
This commit is contained in:
361
tests/unit_tests/agent/test_context_builder_state.py
Normal file
361
tests/unit_tests/agent/test_context_builder_state.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""Tests for ContextAccess.state determination in AgentRunContextBuilder.
|
||||
|
||||
Tests focus on:
|
||||
- Event-first mode: state=True when enable_state=True and state_scopes non-empty
|
||||
- Event-first mode: state=False when enable_state=False
|
||||
- Legacy Query mode: state=False (no persistent state API)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application for testing."""
|
||||
def __init__(self):
|
||||
self.logger = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock()
|
||||
|
||||
|
||||
class TestContextAccessStateDetermination:
|
||||
"""Tests for ContextAccess.state field determination - real calls to _build_context_access."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
return MockApplication()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event(self):
|
||||
"""Create mock event envelope."""
|
||||
return AgentEventEnvelope(
|
||||
event_id='evt_001',
|
||||
event_type='message.received',
|
||||
event_time=1234567890,
|
||||
source='test',
|
||||
bot_id='bot_001',
|
||||
workspace_id='ws_001',
|
||||
conversation_id='conv_001',
|
||||
thread_id=None,
|
||||
actor=ActorContext(actor_type='user', actor_id='user_001'),
|
||||
subject=None,
|
||||
input=AgentInput(text='hello', contents=[], attachments=[]),
|
||||
delivery=DeliveryContext(surface='test', supports_streaming=True),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_descriptor(self):
|
||||
"""Create mock runner descriptor."""
|
||||
descriptor = MagicMock()
|
||||
descriptor.id = 'plugin:test/runner/default'
|
||||
descriptor.protocol_version = '1.0'
|
||||
descriptor.permissions = {}
|
||||
return descriptor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=True when enable_state=True and state_scopes non-empty."""
|
||||
# Create binding with state enabled and non-empty scopes
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor'],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call to _build_context_access
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=True based on binding.state_policy
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_false_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when enable_state=False."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=False,
|
||||
state_scopes=[],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=False
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_true_empty_scopes_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when enable_state=True but state_scopes empty."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=[], # Empty scopes - state not available
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=False (empty scopes means state not available)
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_binding_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when binding is None (legacy mode)."""
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call without binding
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding=None)
|
||||
|
||||
# Verify state=False (no binding = no state policy = state disabled)
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_scope_available_without_conversation(self, mock_app, mock_descriptor):
|
||||
"""State API with runner scope is available even without conversation_id."""
|
||||
mock_event = AgentEventEnvelope(
|
||||
event_id='evt_002',
|
||||
event_type='message.received',
|
||||
event_time=1234567890,
|
||||
source='test',
|
||||
bot_id='bot_001',
|
||||
workspace_id='ws_001',
|
||||
conversation_id=None, # No conversation
|
||||
thread_id=None,
|
||||
actor=ActorContext(actor_type='user', actor_id='user_001'),
|
||||
subject=None,
|
||||
input=AgentInput(text='hello', contents=[], attachments=[]),
|
||||
delivery=DeliveryContext(surface='test', supports_streaming=True),
|
||||
)
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_002',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='workspace', scope_id='ws_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['runner'], # Runner scope doesn't need conversation_id
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# State should be True because runner scope is enabled
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_scopes_all_available(self, mock_app, mock_event, mock_descriptor):
|
||||
"""State API with multiple scopes enabled."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_003',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor', 'subject', 'runner'],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# State should be True with all scopes enabled
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
|
||||
class TestStatePolicyFromBinding:
|
||||
"""Tests for state_policy extraction from binding."""
|
||||
|
||||
def test_state_policy_structure(self):
|
||||
"""State policy has correct structure."""
|
||||
policy = StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor', 'subject', 'runner'],
|
||||
)
|
||||
|
||||
assert policy.enable_state is True
|
||||
assert len(policy.state_scopes) == 4
|
||||
assert 'conversation' in policy.state_scopes
|
||||
|
||||
def test_state_policy_disabled(self):
|
||||
"""State policy can be disabled."""
|
||||
policy = StatePolicy(
|
||||
enable_state=False,
|
||||
state_scopes=[],
|
||||
)
|
||||
|
||||
assert policy.enable_state is False
|
||||
assert len(policy.state_scopes) == 0
|
||||
|
||||
|
||||
class TestBindingWithStatePolicy:
|
||||
"""Tests for binding with state_policy."""
|
||||
|
||||
def test_binding_contains_state_policy(self):
|
||||
"""Binding contains state_policy field."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation'],
|
||||
),
|
||||
)
|
||||
|
||||
assert binding.state_policy is not None
|
||||
assert binding.state_policy.enable_state is True
|
||||
|
||||
|
||||
class TestContextAccessOtherAPIs:
|
||||
"""Tests for other available_apis fields based on permissions."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
return MockApplication()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_apis_based_on_permissions(self, mock_app):
|
||||
"""History APIs availability based on runner permissions."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'history': ['page', 'search'],
|
||||
}
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# History APIs enabled based on permissions
|
||||
assert context_access['available_apis']['history_page'] is True
|
||||
assert context_access['available_apis']['history_search'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_apis_based_on_permissions(self, mock_app):
|
||||
"""Event APIs availability based on runner permissions."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'events': ['get', 'page'],
|
||||
}
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Event APIs enabled based on permissions
|
||||
assert context_access['available_apis']['event_get'] is True
|
||||
assert context_access['available_apis']['event_page'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_artifact_apis_based_on_permissions(self, mock_app):
|
||||
"""Artifact APIs availability based on runner permissions."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'artifacts': ['metadata', 'read'],
|
||||
}
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Artifact APIs enabled based on permissions
|
||||
assert context_access['available_apis']['artifact_metadata'] is True
|
||||
assert context_access['available_apis']['artifact_read'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_permissions_all_apis_disabled(self, mock_app):
|
||||
"""All pull APIs disabled when permissions are empty."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {} # No permissions
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='pipeline', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# All pull APIs should be disabled
|
||||
assert context_access['available_apis']['history_page'] is False
|
||||
assert context_access['available_apis']['history_search'] is False
|
||||
assert context_access['available_apis']['event_get'] is False
|
||||
assert context_access['available_apis']['event_page'] is False
|
||||
assert context_access['available_apis']['artifact_metadata'] is False
|
||||
assert context_access['available_apis']['artifact_read'] is False
|
||||
assert context_access['available_apis']['state'] is False
|
||||
@@ -113,13 +113,24 @@ class TestContextValidation:
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
# Build context
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
# Build context
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# Validate it can be parsed by SDK AgentRunContext
|
||||
# This will raise ValidationError if invalid
|
||||
@@ -162,12 +173,23 @@ class TestContextValidation:
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# Protocol v1 does NOT have these as core fields
|
||||
assert 'messages' not in context_dict, "messages should not be top-level in Protocol v1"
|
||||
@@ -192,12 +214,23 @@ class TestContextValidation:
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# event is REQUIRED in Protocol v1
|
||||
assert context_dict.get('event') is not None, "event is REQUIRED for Protocol v1"
|
||||
@@ -217,12 +250,23 @@ class TestContextValidation:
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# delivery is REQUIRED in Protocol v1
|
||||
assert context_dict.get('delivery') is not None, "delivery is REQUIRED for Protocol v1"
|
||||
|
||||
538
tests/unit_tests/agent/test_state_api_auth.py
Normal file
538
tests/unit_tests/agent/test_state_api_auth.py
Normal file
@@ -0,0 +1,538 @@
|
||||
"""Tests for State API handler authorization in RuntimeConnectionHandler.
|
||||
|
||||
Tests focus on:
|
||||
- STATE_GET authorization
|
||||
- STATE_SET authorization
|
||||
- STATE_DELETE authorization
|
||||
- STATE_LIST authorization
|
||||
|
||||
These tests instantiate real RuntimeConnectionHandler action handlers and verify:
|
||||
- Authorization errors for missing/mismatched caller_plugin_identity
|
||||
- Authorization errors for disabled state or scope
|
||||
- Full flow: set -> get -> list -> delete with real SQLite
|
||||
|
||||
Authorization rules:
|
||||
- caller_plugin_identity is REQUIRED when session has plugin_identity
|
||||
- caller_plugin_identity must match session's plugin_identity
|
||||
- enable_state must be True
|
||||
- scope must be in state_scopes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry, get_session_registry
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore, reset_persistent_state_store
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.runtime.io.connection import Connection
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
# Import shared test fixtures
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
"""Fake connection for testing."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
"""Fake Application for testing."""
|
||||
def __init__(self, db_engine=None):
|
||||
self.logger = MagicMock()
|
||||
self.logger.debug = MagicMock()
|
||||
self.logger.warning = MagicMock()
|
||||
self.logger.error = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry():
|
||||
"""Create a fresh session registry for each test."""
|
||||
return AgentRunSessionRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def persistent_store(db_engine):
|
||||
"""Create a persistent state store with real SQLite."""
|
||||
reset_persistent_state_store()
|
||||
store = PersistentStateStore(db_engine)
|
||||
|
||||
# Create the table
|
||||
from langbot.pkg.entity.persistence.agent_runner_state import AgentRunnerState
|
||||
from sqlalchemy import text
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.run_sync(AgentRunnerState.__table__.create, checkfirst=True)
|
||||
|
||||
yield store
|
||||
reset_persistent_state_store()
|
||||
|
||||
|
||||
class TestStateAPIHandlerAuthorization:
|
||||
"""Tests for State API handler authorization with real action calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_run_id_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing run_id returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Get the STATE_GET action handler (actions dict is keyed by action value string)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call without run_id
|
||||
result = await state_get_handler({'scope': 'conversation', 'key': 'test_key'})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'run_id is required' in result.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_run_not_found_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: run_id not in session registry returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call with non-existent run_id
|
||||
result = await state_get_handler({
|
||||
'run_id': 'nonexistent_run',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not found' in result.message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_caller_plugin_identity_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing caller_plugin_identity when session has plugin_identity returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register session with plugin_identity
|
||||
await session_registry.register(
|
||||
run_id='run_test_missing_identity',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call without caller_plugin_identity
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_missing_identity',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'caller_plugin_identity is required' in result.message
|
||||
|
||||
await session_registry.unregister('run_test_missing_identity')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_caller_identity_mismatch_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: caller_plugin_identity mismatch returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_mismatch',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call with wrong caller_plugin_identity
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_mismatch',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'other/plugin',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'mismatch' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_mismatch')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_enable_state_false_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: enable_state=False returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_disabled',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': False, 'state_scopes': []},
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_disabled',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'disabled' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_disabled')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_scope_not_enabled_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: scope not in state_scopes returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_scope_disabled',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key', 'actor': 'actor_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Request 'actor' scope which is not in state_scopes
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_scope_disabled',
|
||||
'scope': 'actor',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not enabled' in result.message.lower() or 'scope' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_scope_disabled')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_scope_key_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing scope_key in state_context returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_no_scope_key',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, # No scope_keys
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_no_scope_key',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not available' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_no_scope_key')
|
||||
|
||||
|
||||
class TestStateAPIFullFlowWithRealDB:
|
||||
"""Tests for complete State API flow with real SQLite database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_set_get_list_delete_flow(self, session_registry, db_engine, persistent_store):
|
||||
"""Test complete state flow: set -> get -> list -> delete with real SQLite."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register session
|
||||
await session_registry.register(
|
||||
run_id='run_full_flow',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation', 'runner']},
|
||||
state_context={
|
||||
'scope_keys': {
|
||||
'conversation': 'conv:test_runner:binding_1:conv_123',
|
||||
'runner': 'runner:test_runner:binding_1',
|
||||
},
|
||||
'binding_identity': 'binding_1',
|
||||
'conversation_id': 'conv_123',
|
||||
},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Verify session has correct state_context
|
||||
session = await session_registry.get('run_full_flow')
|
||||
assert session is not None
|
||||
state_ctx = session.get('state_context')
|
||||
assert state_ctx is not None, f"state_context is None. Session keys: {list(session.keys())}"
|
||||
assert 'scope_keys' in state_ctx, f"scope_keys not in state_context: {state_ctx}"
|
||||
assert 'conversation' in state_ctx['scope_keys'], f"conversation not in scope_keys: {state_ctx['scope_keys']}"
|
||||
|
||||
# Get handlers (actions dict is keyed by action value string)
|
||||
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
state_list_handler = handler.actions[PluginToRuntimeAction.STATE_LIST.value]
|
||||
state_delete_handler = handler.actions[PluginToRuntimeAction.STATE_DELETE.value]
|
||||
|
||||
# 1. STATE_SET
|
||||
set_result = await state_set_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'value': {'data': 'test_value'},
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert set_result.code == 0
|
||||
assert set_result.data.get('success') is True
|
||||
|
||||
# 2. STATE_GET
|
||||
get_result = await state_get_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert get_result.code == 0
|
||||
assert get_result.data.get('value') == {'data': 'test_value'}
|
||||
|
||||
# 3. STATE_LIST
|
||||
list_result = await state_list_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'prefix': 'external.',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert list_result.code == 0
|
||||
keys = list_result.data.get('keys', [])
|
||||
assert 'external.test_key' in keys
|
||||
|
||||
# 4. STATE_DELETE
|
||||
delete_result = await state_delete_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert delete_result.code == 0
|
||||
|
||||
# 5. Verify deleted
|
||||
get_after_delete = await state_get_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert get_after_delete.code == 0
|
||||
assert get_after_delete.data.get('value') is None
|
||||
|
||||
await session_registry.unregister('run_full_flow')
|
||||
|
||||
|
||||
class TestStateHandlerReadsFromSessionTopLevel:
|
||||
"""Tests verifying handlers read state_policy/state_context from session top-level, not resources."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_handler_reads_state_policy_from_session_top_level(self, session_registry, db_engine, persistent_store):
|
||||
"""Handler reads state_policy from session['state_policy'], not session['resources']['state_policy']."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register with explicit state_policy at top level
|
||||
await session_registry.register(
|
||||
run_id='run_policy_top_level',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': False, 'state_scopes': []}, # Disabled at top level
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
# Verify resources does NOT contain state_policy
|
||||
session = await session_registry.get('run_policy_top_level')
|
||||
assert session is not None
|
||||
assert 'state_policy' not in session.get('resources', {}), \
|
||||
"resources should NOT contain state_policy"
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Should fail because enable_state=False in session['state_policy']
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_policy_top_level',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'disabled' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_policy_top_level')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_handler_reads_state_context_from_session_top_level(self, session_registry, db_engine, persistent_store):
|
||||
"""Handler reads state_context from session['state_context'], not session['resources']['state_context']."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register with explicit state_context at top level
|
||||
await session_registry.register(
|
||||
run_id='run_context_top_level',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key_xyz'}, 'binding_identity': 'binding_xyz'},
|
||||
)
|
||||
|
||||
# Verify resources does NOT contain state_context
|
||||
session = await session_registry.get('run_context_top_level')
|
||||
assert session is not None
|
||||
assert 'state_context' not in session.get('resources', {}), \
|
||||
"resources should NOT contain state_context"
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
|
||||
|
||||
# Should use scope_key from session['state_context']['scope_keys']['conversation']
|
||||
result = await state_set_handler({
|
||||
'run_id': 'run_context_top_level',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'value': 'test_value',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
# Should succeed - scope_key was found in state_context
|
||||
assert result.code == 0
|
||||
|
||||
await session_registry.unregister('run_context_top_level')
|
||||
|
||||
|
||||
class TestResourcesDoesNotContainStateMetadata:
|
||||
"""Tests verifying resources is clean - no state metadata mixed in."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resources_clean_after_register(self, session_registry):
|
||||
"""After register(), resources should not contain state_policy or state_context."""
|
||||
resources = make_resources()
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_resources_clean',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=resources,
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
session = await session_registry.get('run_resources_clean')
|
||||
assert session is not None
|
||||
|
||||
# Verify resources is clean
|
||||
session_resources = session.get('resources', {})
|
||||
assert 'state_policy' not in session_resources, \
|
||||
"session['resources'] should NOT contain state_policy"
|
||||
assert 'state_context' not in session_resources, \
|
||||
"session['resources'] should NOT contain state_context"
|
||||
|
||||
# Verify state metadata is at top level
|
||||
assert 'state_policy' in session
|
||||
assert 'state_context' in session
|
||||
|
||||
await session_registry.unregister('run_resources_clean')
|
||||
@@ -1137,4 +1137,238 @@ class TestStateStorePolicyEnforcement:
|
||||
)
|
||||
|
||||
assert result is False
|
||||
assert any('not enabled' in w for w in logger.warnings)
|
||||
assert any('not enabled' in w for w in logger.warnings)
|
||||
|
||||
|
||||
# ========== Persistent State Store Tests ==========
|
||||
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import tempfile
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
|
||||
|
||||
class TestPersistentStateStore:
|
||||
"""Tests for persistent database-backed state store."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create a temporary async SQLite database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)
|
||||
|
||||
# Create tables
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# Cleanup
|
||||
await engine.dispose()
|
||||
os.unlink(db_path)
|
||||
|
||||
@pytest.fixture
|
||||
async def persistent_store(self, db_engine):
|
||||
"""Create a persistent state store for testing."""
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
|
||||
store = PersistentStateStore(db_engine)
|
||||
yield store
|
||||
await store.clear_all()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_snapshot_empty(self, persistent_store):
|
||||
"""Building snapshot from empty store returns empty scopes."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
|
||||
assert snapshot['conversation'] == {'external.conversation_id': 'conv_001'}
|
||||
assert snapshot['actor'] == {}
|
||||
assert snapshot['subject'] == {}
|
||||
assert snapshot['runner'] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_set_and_get(self, persistent_store):
|
||||
"""State set/get round trip."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
# Set state
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'test_key', {'nested': 'value'}, None
|
||||
)
|
||||
assert success is True
|
||||
assert error is None
|
||||
|
||||
# Get via snapshot
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_binding_isolation(self, persistent_store):
|
||||
"""Different binding_id should have isolated state."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding_a = FakeBinding(binding_id='binding_a')
|
||||
binding_b = FakeBinding(binding_id='binding_b')
|
||||
|
||||
# Set for binding_a
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding_a, descriptor, 'conversation', 'key', 'value_a', None
|
||||
)
|
||||
|
||||
# binding_b should not see binding_a's state
|
||||
snapshot_b = await persistent_store.build_snapshot_from_event(event, binding_b, descriptor)
|
||||
assert snapshot_b['conversation'] == {'external.conversation_id': 'conv_001'}
|
||||
|
||||
# binding_a should see its own state
|
||||
snapshot_a = await persistent_store.build_snapshot_from_event(event, binding_a, descriptor)
|
||||
assert snapshot_a['conversation']['key'] == 'value_a'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_policy_disable_state(self, persistent_store):
|
||||
"""enable_state=False should return empty snapshot and reject updates."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
policy = StatePolicy(enable_state=False)
|
||||
binding = FakeBinding(state_policy=policy)
|
||||
|
||||
# Snapshot should be empty
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot == {'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}}
|
||||
|
||||
# Update should be rejected
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value', None
|
||||
)
|
||||
assert success is False
|
||||
assert 'disabled' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_policy_scope_restriction(self, persistent_store):
|
||||
"""state_scopes should restrict which scopes are accessible."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(
|
||||
conversation_id='conv_001',
|
||||
actor=FakeActorContext(actor_id='user_001'),
|
||||
)
|
||||
policy = StatePolicy(state_scopes=['conversation']) # Only conversation
|
||||
binding = FakeBinding(state_policy=policy)
|
||||
|
||||
# Conversation should work
|
||||
success_conv, _ = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value_conv', None
|
||||
)
|
||||
assert success_conv is True
|
||||
|
||||
# Actor should be rejected
|
||||
success_actor, error_actor = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'actor', 'key', 'value_actor', None
|
||||
)
|
||||
assert success_actor is False
|
||||
assert 'not enabled' in error_actor.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_json_size_limit(self, persistent_store):
|
||||
"""Value exceeding size limit should be rejected."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
# Create a large value (> 256KB)
|
||||
large_value = 'x' * (300 * 1024)
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', large_value, None
|
||||
)
|
||||
assert success is False
|
||||
assert 'exceeds limit' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_not_json_serializable(self, persistent_store):
|
||||
"""Non-JSON-serializable value should be rejected."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
# Create a non-serializable value (set is not JSON-serializable)
|
||||
non_serializable = {'key': {1, 2, 3}}
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', non_serializable, None
|
||||
)
|
||||
assert success is False
|
||||
assert 'json' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_list(self, persistent_store):
|
||||
"""State list should return keys with optional prefix filter."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
# Set multiple keys
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'external.id', '123', None
|
||||
)
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'external.name', 'test', None
|
||||
)
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'memory.key', 'value', None
|
||||
)
|
||||
|
||||
# Build scope key for list
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
|
||||
temp_store = PersistentStateStore(None)
|
||||
scope_key = temp_store._make_conversation_scope_key(event, binding, descriptor)
|
||||
|
||||
# List all keys
|
||||
keys, has_more = await persistent_store.state_list(scope_key)
|
||||
assert len(keys) == 3
|
||||
assert has_more is False
|
||||
|
||||
# List with prefix
|
||||
keys_ext, _ = await persistent_store.state_list(scope_key, prefix='external.')
|
||||
assert len(keys_ext) == 2
|
||||
assert 'external.id' in keys_ext
|
||||
assert 'external.name' in keys_ext
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_delete(self, persistent_store):
|
||||
"""State delete should remove key."""
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
# Set and verify
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value', None
|
||||
)
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot['conversation']['key'] == 'value'
|
||||
|
||||
# Build scope key for delete
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
|
||||
temp_store = PersistentStateStore(None)
|
||||
scope_key = temp_store._make_conversation_scope_key(event, binding, descriptor)
|
||||
|
||||
# Delete
|
||||
deleted = await persistent_store.state_delete(scope_key, 'key')
|
||||
assert deleted is True
|
||||
|
||||
# Verify deleted
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert 'key' not in snapshot['conversation']
|
||||
|
||||
# Delete non-existent should return False
|
||||
deleted_again = await persistent_store.state_delete(scope_key, 'key')
|
||||
assert deleted_again is False
|
||||
Reference in New Issue
Block a user