mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-24 14:34:20 +00:00
feat(agent-runner): add plugin runner host integration
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
"""Tests for agent runner subsystem."""
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Shared test fixtures for agent runner tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
|
||||
def make_resources(
|
||||
models: list[dict] | None = None,
|
||||
tools: list[dict] | None = None,
|
||||
knowledge_bases: list[dict] | None = None,
|
||||
skills: list[dict] | None = None,
|
||||
storage: dict | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create a minimal AgentResources dict for testing.
|
||||
|
||||
Args:
|
||||
models: List of model dicts with 'model_id' key
|
||||
tools: List of tool dicts with 'tool_name' key
|
||||
knowledge_bases: List of KB dicts with 'kb_id' key
|
||||
skills: List of skill dicts with 'skill_name' key
|
||||
storage: Storage permissions dict
|
||||
Returns:
|
||||
AgentResources dict with all required fields
|
||||
"""
|
||||
return {
|
||||
'models': models or [],
|
||||
'tools': tools or [],
|
||||
'knowledge_bases': knowledge_bases or [],
|
||||
'skills': skills or [],
|
||||
'storage': storage or {'plugin_storage': False, 'workspace_storage': False},
|
||||
'platform_capabilities': {},
|
||||
}
|
||||
|
||||
|
||||
def make_session(
|
||||
run_id: str = 'test-run-id',
|
||||
runner_id: str = 'plugin:test/test-runner/default',
|
||||
query_id: int | None = 1,
|
||||
plugin_identity: str = 'test/test-runner',
|
||||
resources: dict | None = None,
|
||||
conversation_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
available_apis: dict[str, bool] | None = None,
|
||||
state_policy: dict[str, typing.Any] | None = None,
|
||||
state_context: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create a minimal AgentRunSession dict for testing.
|
||||
|
||||
Args:
|
||||
run_id: Unique run identifier
|
||||
runner_id: Runner descriptor ID
|
||||
query_id: Host entry query ID
|
||||
plugin_identity: Plugin identifier (author/name)
|
||||
resources: AgentResources dict (uses make_resources() default if None)
|
||||
|
||||
Returns:
|
||||
AgentRunSession dict with run-scoped authorization snapshot
|
||||
"""
|
||||
import time
|
||||
now = int(time.time())
|
||||
res = resources if resources is not None else make_resources()
|
||||
apis = available_apis if available_apis is not None else {}
|
||||
policy = (
|
||||
state_policy
|
||||
if state_policy is not None
|
||||
else {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
|
||||
)
|
||||
context = state_context if state_context is not None else {}
|
||||
|
||||
authorized_ids: dict[str, set[str]] = {
|
||||
'model': {m.get('model_id') for m in res.get('models', [])},
|
||||
'tool': {t.get('tool_name') for t in res.get('tools', [])},
|
||||
'knowledge_base': {kb.get('kb_id') for kb in res.get('knowledge_bases', [])},
|
||||
'skill': {s.get('skill_name') for s in res.get('skills', [])},
|
||||
}
|
||||
authorized_operations: dict[str, dict[str, set[str]]] = {
|
||||
'model': {
|
||||
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank'])
|
||||
for m in res.get('models', [])
|
||||
if m.get('model_id')
|
||||
},
|
||||
'tool': {
|
||||
t.get('tool_name'): set(t.get('operations') or ['detail', 'call'])
|
||||
for t in res.get('tools', [])
|
||||
if t.get('tool_name')
|
||||
},
|
||||
'knowledge_base': {
|
||||
kb.get('kb_id'): set(kb.get('operations') or ['list', 'retrieve'])
|
||||
for kb in res.get('knowledge_bases', [])
|
||||
if kb.get('kb_id')
|
||||
},
|
||||
'skill': {
|
||||
s.get('skill_name'): set(s.get('operations') or ['activate'])
|
||||
for s in res.get('skills', [])
|
||||
if s.get('skill_name')
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
'run_id': run_id,
|
||||
'runner_id': runner_id,
|
||||
'query_id': query_id,
|
||||
'plugin_identity': plugin_identity,
|
||||
'authorization': {
|
||||
'resources': res,
|
||||
'available_apis': apis,
|
||||
'conversation_id': conversation_id,
|
||||
'bot_id': bot_id,
|
||||
'workspace_id': workspace_id,
|
||||
'thread_id': thread_id,
|
||||
'state_policy': policy,
|
||||
'state_context': context,
|
||||
'authorized_ids': authorized_ids,
|
||||
'authorized_operations': authorized_operations,
|
||||
},
|
||||
'status': {
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
"""Tests for ChatMessageHandler behavior with AgentRunOrchestrator.
|
||||
|
||||
Tests focus on:
|
||||
- Streaming mode behavior (single resp_message_id, pop/append pattern)
|
||||
- Non-streaming mode behavior (no pop)
|
||||
- Orchestrator invocation
|
||||
- Error handling for RunnerNotFoundError, RunnerExecutionError
|
||||
|
||||
Avoids circular imports by using proper import structure.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from langbot.pkg.agent.runner.errors import (
|
||||
RunnerNotFoundError,
|
||||
RunnerExecutionError,
|
||||
RunnerNotAuthorizedError,
|
||||
)
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
# Define mock classes in dependency order (no forward references needed)
|
||||
|
||||
class MockLauncherType:
|
||||
value = 'person'
|
||||
|
||||
|
||||
class MockConversation:
|
||||
def __init__(self):
|
||||
self.uuid = 'conv-uuid'
|
||||
self.messages = []
|
||||
|
||||
|
||||
class MockMessage:
|
||||
role = 'user'
|
||||
content = 'Hello'
|
||||
|
||||
|
||||
class MockAdapter:
|
||||
is_stream = False
|
||||
|
||||
async def is_stream_output_supported(self):
|
||||
return self.is_stream
|
||||
|
||||
async def create_message_card(self, resp_message_id, message_event):
|
||||
pass
|
||||
|
||||
|
||||
class MockSession:
|
||||
launcher_type = MockLauncherType()
|
||||
launcher_id = 'user123'
|
||||
|
||||
def __init__(self):
|
||||
self.using_conversation = MockConversation()
|
||||
|
||||
|
||||
class MockQuery:
|
||||
"""Mock Query for testing."""
|
||||
def __init__(self):
|
||||
self.query_id = 1
|
||||
self.launcher_type = MockLauncherType()
|
||||
self.launcher_id = 'user123'
|
||||
self.sender_id = 'user123'
|
||||
self.bot_uuid = 'bot-uuid'
|
||||
self.pipeline_uuid = 'pipeline-uuid'
|
||||
self.pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
},
|
||||
'runner_config': {},
|
||||
},
|
||||
'output': {
|
||||
'misc': {
|
||||
'exception-handling': 'show-hint',
|
||||
'failure-hint': 'Request failed.',
|
||||
},
|
||||
},
|
||||
}
|
||||
self.variables = {}
|
||||
self.session = MockSession()
|
||||
self.user_message = MockMessage()
|
||||
self.messages = []
|
||||
self.resp_messages = []
|
||||
self.resp_message_chain = None
|
||||
self.adapter = MockAdapter()
|
||||
self.message_event = MagicMock()
|
||||
self.message_chain = MagicMock()
|
||||
|
||||
|
||||
class MockMessageChunk:
|
||||
"""Mock MessageChunk for testing."""
|
||||
def __init__(self, content, resp_message_id=None):
|
||||
self.role = 'assistant'
|
||||
self.content = content
|
||||
self.resp_message_id = resp_message_id
|
||||
self.tool_calls = []
|
||||
self.is_final = False
|
||||
|
||||
def readable_str(self):
|
||||
return self.content
|
||||
|
||||
|
||||
class MockEventContext:
|
||||
"""Mock event context for testing."""
|
||||
def __init__(self, prevented=False, reply_message_chain=None, user_message_alter=None):
|
||||
self._prevented = prevented
|
||||
self.event = MagicMock()
|
||||
self.event.reply_message_chain = reply_message_chain
|
||||
self.event.user_message_alter = user_message_alter
|
||||
|
||||
def is_prevented_default(self):
|
||||
return self._prevented
|
||||
|
||||
|
||||
class MockAgentRunOrchestrator:
|
||||
"""Mock AgentRunOrchestrator for testing."""
|
||||
def __init__(self, chunks=None, error=None):
|
||||
self._chunks = chunks or []
|
||||
self._error = error
|
||||
|
||||
async def run_from_query(self, query):
|
||||
"""Async generator that yields chunks or raises error."""
|
||||
if self._error:
|
||||
raise self._error
|
||||
for chunk in self._chunks:
|
||||
yield chunk
|
||||
|
||||
async def try_claim_steering_from_query(self, query):
|
||||
return False
|
||||
|
||||
def resolve_runner_id_for_telemetry(self, query):
|
||||
return 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application for testing."""
|
||||
def __init__(self, orchestrator=None):
|
||||
self.agent_run_orchestrator = orchestrator or MockAgentRunOrchestrator()
|
||||
self.logger = MagicMock()
|
||||
self.logger.info = MagicMock()
|
||||
self.logger.debug = MagicMock()
|
||||
self.logger.warning = MagicMock()
|
||||
self.logger.error = MagicMock()
|
||||
|
||||
# Mock plugin_connector
|
||||
self.plugin_connector = MagicMock()
|
||||
self.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext())
|
||||
|
||||
# Mock telemetry
|
||||
self.telemetry = MagicMock()
|
||||
self.telemetry.start_send_task = AsyncMock()
|
||||
|
||||
# Mock survey
|
||||
self.survey = MagicMock()
|
||||
self.survey.trigger_event = AsyncMock()
|
||||
|
||||
# Mock model_mgr
|
||||
self.model_mgr = MagicMock()
|
||||
self.model_mgr.get_model_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
# Mock sess_mgr
|
||||
self.sess_mgr = MagicMock()
|
||||
self.sess_mgr.get_conversation = AsyncMock()
|
||||
|
||||
|
||||
class TestStreamingBehavior:
|
||||
"""Tests for streaming mode behavior."""
|
||||
|
||||
def test_single_resp_message_id_for_streaming(self):
|
||||
"""Streaming mode should use single resp_message_id for entire response."""
|
||||
# Simulate the streaming logic: resp_message_id created outside loop
|
||||
resp_message_id = uuid.uuid4()
|
||||
|
||||
chunks = ['Hello', ' World', '!']
|
||||
resp_messages = []
|
||||
|
||||
for chunk in chunks:
|
||||
result = MockMessageChunk(chunk)
|
||||
result.resp_message_id = str(resp_message_id)
|
||||
|
||||
# Pop old chunk (streaming behavior)
|
||||
if resp_messages:
|
||||
resp_messages.pop()
|
||||
resp_messages.append(result)
|
||||
|
||||
# All chunks should have same resp_message_id
|
||||
assert len(resp_messages) == 1 # Only last chunk remains after pop/append
|
||||
assert resp_messages[0].resp_message_id == str(resp_message_id)
|
||||
|
||||
def test_pop_before_append_in_streaming(self):
|
||||
"""Streaming mode should pop old chunk before appending new."""
|
||||
resp_message_id = uuid.uuid4()
|
||||
resp_messages = []
|
||||
|
||||
# First chunk - no pop
|
||||
chunk1 = MockMessageChunk('Hello')
|
||||
chunk1.resp_message_id = str(resp_message_id)
|
||||
resp_messages.append(chunk1)
|
||||
assert len(resp_messages) == 1
|
||||
|
||||
# Second chunk - pop first, then append
|
||||
if resp_messages:
|
||||
resp_messages.pop()
|
||||
chunk2 = MockMessageChunk('Hello World')
|
||||
chunk2.resp_message_id = str(resp_message_id)
|
||||
resp_messages.append(chunk2)
|
||||
assert len(resp_messages) == 1
|
||||
assert resp_messages[0].content == 'Hello World'
|
||||
|
||||
def test_non_streaming_no_pop(self):
|
||||
"""Non-streaming mode should NOT pop previous responses."""
|
||||
resp_messages = []
|
||||
|
||||
# First message
|
||||
msg1 = MockMessageChunk('Response 1')
|
||||
resp_messages.append(msg1)
|
||||
assert len(resp_messages) == 1
|
||||
|
||||
# Second message - should NOT pop in non-streaming
|
||||
msg2 = MockMessageChunk('Response 2')
|
||||
resp_messages.append(msg2)
|
||||
assert len(resp_messages) == 2
|
||||
|
||||
|
||||
class TestConfigMigrationInChatHandler:
|
||||
"""Tests for ConfigMigration usage in chat handler context."""
|
||||
|
||||
def test_resolve_runner_id_from_pipeline_config(self):
|
||||
"""Chat handler should use ConfigMigration to resolve runner ID."""
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
def test_resolve_runner_id_from_old_format(self):
|
||||
"""ConfigMigration resolves old runner aliases for compatibility."""
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'runner': 'local-agent',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Tests for orchestrator error handling."""
|
||||
|
||||
def test_runner_not_found_error_properties(self):
|
||||
"""RunnerNotFoundError should have runner_id property."""
|
||||
error = RunnerNotFoundError('plugin:notexist/unknown/default')
|
||||
assert error.runner_id == 'plugin:notexist/unknown/default'
|
||||
assert 'not found' in str(error)
|
||||
|
||||
def test_runner_execution_error_retryable(self):
|
||||
"""RunnerExecutionError should have retryable property."""
|
||||
error = RunnerExecutionError(
|
||||
'plugin:langbot/local-agent/default',
|
||||
'Upstream timeout',
|
||||
retryable=True,
|
||||
)
|
||||
assert error.runner_id == 'plugin:langbot/local-agent/default'
|
||||
assert error.retryable is True
|
||||
assert 'timeout' in str(error)
|
||||
|
||||
def test_runner_execution_error_not_retryable(self):
|
||||
"""RunnerExecutionError can be non-retryable."""
|
||||
error = RunnerExecutionError(
|
||||
'plugin:langbot/local-agent/default',
|
||||
'Configuration error',
|
||||
retryable=False,
|
||||
)
|
||||
assert error.retryable is False
|
||||
|
||||
def test_runner_not_authorized_error_properties(self):
|
||||
"""RunnerNotAuthorizedError should have bound_plugins property."""
|
||||
error = RunnerNotAuthorizedError(
|
||||
'plugin:langbot/local-agent/default',
|
||||
['langbot/dify-agent'],
|
||||
)
|
||||
assert error.runner_id == 'plugin:langbot/local-agent/default'
|
||||
assert error.bound_plugins == ['langbot/dify-agent']
|
||||
|
||||
|
||||
class TestChatHandlerImports:
|
||||
"""Test that chat handler can be imported without circular import."""
|
||||
|
||||
def test_import_chat_handler_module(self):
|
||||
"""Import chat handler module should work."""
|
||||
# This test verifies the import works without circular dependency
|
||||
from langbot.pkg.pipeline.process.handlers import chat
|
||||
assert chat.ChatMessageHandler is not None
|
||||
|
||||
def test_chat_handler_class_exists(self):
|
||||
"""ChatMessageHandler class should be defined."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
assert ChatMessageHandler.__name__ == 'ChatMessageHandler'
|
||||
|
||||
def test_chat_handler_has_handle_method(self):
|
||||
"""ChatMessageHandler should have async generator handle method."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
assert hasattr(ChatMessageHandler, 'handle')
|
||||
# handle returns AsyncGenerator, so check for async generator function
|
||||
import inspect
|
||||
assert inspect.isasyncgenfunction(ChatMessageHandler.handle)
|
||||
|
||||
|
||||
class TestChatHandlerAsyncBehavior:
|
||||
"""Real async tests for ChatMessageHandler.handle() with mocked orchestrator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_single_resp_message_id(self):
|
||||
"""Streaming mode: all chunks should have same resp_message_id."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
# Create chunks for streaming
|
||||
chunks = [
|
||||
MockMessageChunk('Hello'),
|
||||
MockMessageChunk('Hello World'),
|
||||
MockMessageChunk('Hello World!'),
|
||||
]
|
||||
|
||||
orchestrator = MockAgentRunOrchestrator(chunks=chunks)
|
||||
mock_ap = MockApplication(orchestrator=orchestrator)
|
||||
|
||||
# Mock event context to not prevent default
|
||||
event_ctx = MockEventContext(prevented=False)
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
|
||||
|
||||
query = MockQuery()
|
||||
query.adapter.is_stream = True # Enable streaming mode
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
# Mock event creation and StageProcessResult to bypass pydantic validation
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(result_type=kwargs.get('result_type', entities.ResultType.CONTINUE))
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
# Verify single resp_message_id
|
||||
resp_ids = [msg.resp_message_id for msg in query.resp_messages if hasattr(msg, 'resp_message_id')]
|
||||
assert len(set(resp_ids)) == 1 # All same ID
|
||||
|
||||
# Verify pop/append pattern: only last chunk remains
|
||||
assert len(query.resp_messages) == 1
|
||||
assert query.resp_messages[0].content == 'Hello World!'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_no_pop(self):
|
||||
"""Non-streaming mode: all chunks should remain."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
chunks = [
|
||||
MockMessageChunk('Response 1'),
|
||||
MockMessageChunk('Response 2'),
|
||||
]
|
||||
|
||||
orchestrator = MockAgentRunOrchestrator(chunks=chunks)
|
||||
mock_ap = MockApplication(orchestrator=orchestrator)
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext(prevented=False))
|
||||
|
||||
query = MockQuery()
|
||||
query.adapter.is_stream = False # Disable streaming mode
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(result_type=kwargs.get('result_type', entities.ResultType.CONTINUE))
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
# No pop: all chunks should remain
|
||||
assert len(query.resp_messages) == 2
|
||||
assert query.resp_messages[0].content == 'Response 1'
|
||||
assert query.resp_messages[1].content == 'Response 2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turn_recreates_conversation_if_tool_resets_it(self):
|
||||
"""Agent turn bookkeeping should tolerate CREATE_NEW_CONVERSATION during runner execution."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
response = MockMessageChunk('Tool response')
|
||||
new_conversation = MockConversation()
|
||||
|
||||
class ResetConversationOrchestrator(MockAgentRunOrchestrator):
|
||||
async def run_from_query(self, query):
|
||||
query.session.using_conversation = None
|
||||
yield response
|
||||
|
||||
mock_ap = MockApplication(orchestrator=ResetConversationOrchestrator())
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext(prevented=False))
|
||||
mock_ap.sess_mgr.get_conversation = AsyncMock(return_value=new_conversation)
|
||||
|
||||
query = MockQuery()
|
||||
query.adapter.is_stream = False
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(result_type=kwargs.get('result_type', entities.ResultType.CONTINUE))
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
mock_ap.sess_mgr.get_conversation.assert_awaited_once()
|
||||
assert query.session.using_conversation is new_conversation
|
||||
assert new_conversation.messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_not_found_error(self):
|
||||
"""Handler should catch RunnerNotFoundError and return INTERRUPT."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
orchestrator = MockAgentRunOrchestrator(
|
||||
error=RunnerNotFoundError('plugin:notexist/unknown/default')
|
||||
)
|
||||
mock_ap = MockApplication(orchestrator=orchestrator)
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext(prevented=False))
|
||||
|
||||
query = MockQuery()
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(
|
||||
result_type=kwargs.get('result_type'),
|
||||
user_notice=kwargs.get('user_notice'),
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
# Should return INTERRUPT with user_notice
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert 'not found' in results[0].user_notice
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_not_authorized_error(self):
|
||||
"""Handler should catch RunnerNotAuthorizedError and return INTERRUPT."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
orchestrator = MockAgentRunOrchestrator(
|
||||
error=RunnerNotAuthorizedError('plugin:langbot/local-agent/default', ['other/plugin'])
|
||||
)
|
||||
mock_ap = MockApplication(orchestrator=orchestrator)
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext(prevented=False))
|
||||
|
||||
query = MockQuery()
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(
|
||||
result_type=kwargs.get('result_type'),
|
||||
user_notice=kwargs.get('user_notice'),
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert 'not authorized' in results[0].user_notice
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_execution_error_retryable(self):
|
||||
"""Handler should catch retryable RunnerExecutionError."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
orchestrator = MockAgentRunOrchestrator(
|
||||
error=RunnerExecutionError('plugin:langbot/local-agent/default', 'timeout', retryable=True)
|
||||
)
|
||||
mock_ap = MockApplication(orchestrator=orchestrator)
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(return_value=MockEventContext(prevented=False))
|
||||
|
||||
query = MockQuery()
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(
|
||||
result_type=kwargs.get('result_type'),
|
||||
user_notice=kwargs.get('user_notice'),
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert 'temporarily unavailable' in results[0].user_notice
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevented_default_with_reply(self):
|
||||
"""When event prevented default with reply, use reply message."""
|
||||
from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
|
||||
from langbot.pkg.pipeline import entities
|
||||
|
||||
# Mock reply message chain
|
||||
reply_chain = MockMessageChunk('Reply from plugin')
|
||||
|
||||
mock_ap = MockApplication()
|
||||
mock_ap.plugin_connector.emit_event = AsyncMock(
|
||||
return_value=MockEventContext(prevented=True, reply_message_chain=reply_chain)
|
||||
)
|
||||
|
||||
query = MockQuery()
|
||||
|
||||
handler = ChatMessageHandler(mock_ap)
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.return_value = MagicMock()
|
||||
|
||||
def make_result(*args, **kwargs):
|
||||
return MagicMock(result_type=kwargs.get('result_type', entities.ResultType.CONTINUE))
|
||||
|
||||
with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
|
||||
patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
|
||||
mock_events_module.PersonNormalMessageReceived = mock_event
|
||||
mock_events_module.GroupNormalMessageReceived = mock_event
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
# Should return CONTINUE with reply message
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
assert len(query.resp_messages) == 1
|
||||
@@ -0,0 +1,257 @@
|
||||
"""Tests for current AgentRunner config helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
class TestResolveRunnerId:
|
||||
"""Tests for ConfigMigration.resolve_runner_id."""
|
||||
|
||||
def test_resolve_current_runner_id(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
def test_resolves_old_runner_field(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'runner': 'local-agent',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner_id = ConfigMigration.resolve_runner_id(pipeline_config)
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
def test_resolves_deerflow_and_weknora_legacy_runner_fields(self):
|
||||
assert (
|
||||
ConfigMigration.resolve_runner_id(
|
||||
{
|
||||
'ai': {
|
||||
'runner': {
|
||||
'runner': 'deerflow-api',
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
== 'plugin:langbot/deerflow-agent/default'
|
||||
)
|
||||
assert (
|
||||
ConfigMigration.resolve_runner_id(
|
||||
{
|
||||
'ai': {
|
||||
'runner': {
|
||||
'runner': 'weknora-api',
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
== 'plugin:langbot/weknora-agent/default'
|
||||
)
|
||||
|
||||
def test_resolve_no_runner_config(self):
|
||||
runner_id = ConfigMigration.resolve_runner_id({})
|
||||
assert runner_id is None
|
||||
|
||||
|
||||
class TestResolveRunnerConfig:
|
||||
"""Tests for ConfigMigration.resolve_runner_config."""
|
||||
|
||||
def test_resolve_current_config(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'model': 'uuid-123',
|
||||
'custom_option': 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config = ConfigMigration.resolve_runner_config(
|
||||
pipeline_config,
|
||||
'plugin:langbot/local-agent/default',
|
||||
)
|
||||
assert config == {'model': 'uuid-123', 'custom_option': 10}
|
||||
|
||||
def test_reads_old_runner_block(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'local-agent': {
|
||||
'model': 'uuid-123',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config = ConfigMigration.resolve_runner_config(
|
||||
pipeline_config,
|
||||
'plugin:langbot/local-agent/default',
|
||||
)
|
||||
assert config == {'model': {'primary': 'uuid-123', 'fallbacks': []}}
|
||||
|
||||
def test_reads_deerflow_and_weknora_legacy_runner_blocks(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'deerflow-api': {
|
||||
'api-base': 'http://127.0.0.1:2026',
|
||||
'assistant-id': 'lead_agent',
|
||||
},
|
||||
'weknora-api': {
|
||||
'base-url': 'http://localhost:8080/api/v1',
|
||||
'app-type': 'agent',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
deerflow_config = ConfigMigration.resolve_runner_config(
|
||||
pipeline_config,
|
||||
'plugin:langbot/deerflow-agent/default',
|
||||
)
|
||||
weknora_config = ConfigMigration.resolve_runner_config(
|
||||
pipeline_config,
|
||||
'plugin:langbot/weknora-agent/default',
|
||||
)
|
||||
|
||||
assert deerflow_config == {
|
||||
'api-base': 'http://127.0.0.1:2026',
|
||||
'assistant-id': 'lead_agent',
|
||||
}
|
||||
assert weknora_config == {
|
||||
'base-url': 'http://localhost:8080/api/v1',
|
||||
'app-type': 'agent',
|
||||
}
|
||||
|
||||
def test_resolve_no_config(self):
|
||||
config = ConfigMigration.resolve_runner_config(
|
||||
{},
|
||||
'plugin:langbot/local-agent/default',
|
||||
)
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestGetExpireTime:
|
||||
"""Tests for ConfigMigration.get_expire_time."""
|
||||
|
||||
def test_get_expire_time_zero(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'expire-time': 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expire_time = ConfigMigration.get_expire_time(pipeline_config)
|
||||
assert expire_time == 0
|
||||
|
||||
def test_get_expire_time_positive(self):
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'expire-time': 3600,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expire_time = ConfigMigration.get_expire_time(pipeline_config)
|
||||
assert expire_time == 3600
|
||||
|
||||
def test_get_expire_time_default(self):
|
||||
expire_time = ConfigMigration.get_expire_time({})
|
||||
assert expire_time == 0
|
||||
|
||||
|
||||
class TestNormalizePipelineConfig:
|
||||
"""Tests for ConfigMigration.migrate_pipeline_config."""
|
||||
|
||||
def test_normalizes_current_containers(self):
|
||||
config = {'ai': {}}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated == {'ai': {'runner': {}, 'runner_config': {}}}
|
||||
|
||||
def test_preserves_current_config(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'id': 'plugin:test/my-runner/default'},
|
||||
'runner_config': {
|
||||
'plugin:test/my-runner/default': {'custom-option': 20},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:test/my-runner/default'
|
||||
assert migrated['ai']['runner_config']['plugin:test/my-runner/default']['custom-option'] == 20
|
||||
|
||||
def test_migrates_old_runner_blocks(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': 'old-model', 'knowledge-base': 'kb_1'},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:langbot/local-agent/default'
|
||||
assert 'runner' not in migrated['ai']['runner']
|
||||
assert 'local-agent' not in migrated['ai']
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default'] == {
|
||||
'model': {'primary': 'old-model', 'fallbacks': []},
|
||||
'knowledge-bases': ['kb_1'],
|
||||
}
|
||||
|
||||
def test_migrates_deerflow_legacy_runner_block(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'runner': 'deerflow-api'},
|
||||
'deerflow-api': {
|
||||
'api-base': 'http://127.0.0.1:2026',
|
||||
'assistant-id': 'lead_agent',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:langbot/deerflow-agent/default'
|
||||
assert 'runner' not in migrated['ai']['runner']
|
||||
assert 'deerflow-api' not in migrated['ai']
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/deerflow-agent/default'] == {
|
||||
'api-base': 'http://127.0.0.1:2026',
|
||||
'assistant-id': 'lead_agent',
|
||||
}
|
||||
|
||||
def test_migrates_weknora_legacy_runner_block(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'runner': 'weknora-api'},
|
||||
'weknora-api': {
|
||||
'base-url': 'http://localhost:8080/api/v1',
|
||||
'app-type': 'agent',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:langbot/weknora-agent/default'
|
||||
assert 'runner' not in migrated['ai']['runner']
|
||||
assert 'weknora-api' not in migrated['ai']
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/weknora-agent/default'] == {
|
||||
'base-url': 'http://localhost:8080/api/v1',
|
||||
'app-type': 'agent',
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests for persisted AgentRunner config shape."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
class TestMigratePipelineConfig:
|
||||
"""Tests for ConfigMigration.migrate_pipeline_config."""
|
||||
|
||||
def test_current_format_config_stays_unchanged(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
'expire-time': 0,
|
||||
},
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'model': {'primary': '', 'fallbacks': []},
|
||||
'custom-option': 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:langbot/local-agent/default'
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default']['custom-option'] == 10
|
||||
|
||||
def test_old_runner_field_is_mapped(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {
|
||||
'runner': 'local-agent',
|
||||
'expire-time': 3600,
|
||||
},
|
||||
'local-agent': {
|
||||
'model': 'old-model',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
|
||||
assert migrated['ai']['runner'] == {
|
||||
'expire-time': 3600,
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
}
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default'] == {
|
||||
'model': {'primary': 'old-model', 'fallbacks': []},
|
||||
}
|
||||
assert 'local-agent' not in migrated['ai']
|
||||
|
||||
def test_empty_config_is_unchanged(self):
|
||||
config = {}
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
assert migrated == {}
|
||||
|
||||
def test_config_without_ai_section_is_unchanged(self):
|
||||
config = {'trigger': {}}
|
||||
migrated = ConfigMigration.migrate_pipeline_config(config)
|
||||
assert migrated == {'trigger': {}}
|
||||
|
||||
|
||||
class TestDefaultPipelineConfig:
|
||||
"""Tests for default-pipeline-config.json format."""
|
||||
|
||||
def test_default_config_is_current_format(self):
|
||||
from langbot.pkg.utils import paths as path_utils
|
||||
|
||||
template_path = path_utils.get_resource_path('templates/default-pipeline-config.json')
|
||||
with open(template_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert 'ai' in config
|
||||
assert 'runner' in config['ai']
|
||||
assert 'id' in config['ai']['runner']
|
||||
assert config['ai']['runner']['id'] == ''
|
||||
assert 'runner_config' in config['ai']
|
||||
assert config['ai']['runner_config'] == {}
|
||||
assert 'local-agent' not in config['ai']
|
||||
|
||||
|
||||
class TestResolveRunnerId:
|
||||
"""Tests for current runner id resolution."""
|
||||
|
||||
def test_resolve_current_id(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'id': 'plugin:test/my-runner/default'},
|
||||
},
|
||||
}
|
||||
runner_id = ConfigMigration.resolve_runner_id(config)
|
||||
assert runner_id == 'plugin:test/my-runner/default'
|
||||
|
||||
def test_old_runner_field_is_mapped(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
},
|
||||
}
|
||||
runner_id = ConfigMigration.resolve_runner_id(config)
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class TestResolveRunnerConfig:
|
||||
"""Tests for runtime runner config resolution."""
|
||||
|
||||
def test_resolve_current_config(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {'custom-option': 20},
|
||||
},
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
assert runner_config['custom-option'] == 20
|
||||
|
||||
def test_old_runner_block_is_read(self):
|
||||
config = {
|
||||
'ai': {
|
||||
'local-agent': {'custom-option': 20},
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
assert runner_config == {'custom-option': 20}
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for Query entry adapter params packaging."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
|
||||
|
||||
class TestBuildParams:
|
||||
"""Tests for QueryEntryAdapter.build_params filtering."""
|
||||
|
||||
def test_params_empty_when_no_variables(self):
|
||||
query = type('Query', (), {'variables': None})()
|
||||
assert QueryEntryAdapter.build_params(query) == {}
|
||||
|
||||
def test_params_filters_underscore_prefix(self):
|
||||
query = type('Query', (), {
|
||||
'variables': {
|
||||
'_internal_var': 'should_be_excluded',
|
||||
'_pipeline_bound_plugins': ['a/b'],
|
||||
'_monitoring_bot_name': 'Bot',
|
||||
'public_var': 'should_be_included',
|
||||
},
|
||||
})()
|
||||
|
||||
params = QueryEntryAdapter.build_params(query)
|
||||
assert '_internal_var' not in params
|
||||
assert '_pipeline_bound_plugins' not in params
|
||||
assert '_monitoring_bot_name' not in params
|
||||
assert params['public_var'] == 'should_be_included'
|
||||
|
||||
def test_params_filters_sensitive_naming(self):
|
||||
query = type('Query', (), {
|
||||
'variables': {
|
||||
'api_key': 'secret123',
|
||||
'API_KEY': 'secret456',
|
||||
'token': 'tok123',
|
||||
'secret': 'sec123',
|
||||
'password': 'pass123',
|
||||
'credential': 'cred123',
|
||||
'user_api_key': 'should_be_excluded',
|
||||
'user_secret_key': 'should_be_excluded',
|
||||
'my_token_value': 'should_be_excluded',
|
||||
'user_password_hash': 'should_be_excluded',
|
||||
'public_name': 'should_be_included',
|
||||
'safe_value': 'should_be_included',
|
||||
},
|
||||
})()
|
||||
|
||||
params = QueryEntryAdapter.build_params(query)
|
||||
assert 'api_key' not in params
|
||||
assert 'API_KEY' not in params
|
||||
assert 'token' not in params
|
||||
assert 'secret' not in params
|
||||
assert 'password' not in params
|
||||
assert 'credential' not in params
|
||||
assert 'user_api_key' not in params
|
||||
assert 'user_secret_key' not in params
|
||||
assert 'my_token_value' not in params
|
||||
assert 'user_password_hash' not in params
|
||||
assert 'public_name' in params
|
||||
assert 'safe_value' in params
|
||||
|
||||
def test_params_keeps_common_public_vars(self):
|
||||
query = type('Query', (), {
|
||||
'variables': {
|
||||
'launcher_type': 'telegram',
|
||||
'launcher_id': 'group_123',
|
||||
'sender_id': 'user_001',
|
||||
'session_id': 'sess_abc',
|
||||
'msg_create_time': 1234567890,
|
||||
'group_name': 'Tech Group',
|
||||
'sender_name': 'John',
|
||||
'user_message_text': 'Hello world',
|
||||
},
|
||||
})()
|
||||
|
||||
params = QueryEntryAdapter.build_params(query)
|
||||
assert params['launcher_type'] == 'telegram'
|
||||
assert params['launcher_id'] == 'group_123'
|
||||
assert params['sender_id'] == 'user_001'
|
||||
assert params['session_id'] == 'sess_abc'
|
||||
assert params['msg_create_time'] == 1234567890
|
||||
assert params['group_name'] == 'Tech Group'
|
||||
assert params['sender_name'] == 'John'
|
||||
assert params['user_message_text'] == 'Hello world'
|
||||
|
||||
def test_params_filters_non_json_serializable(self):
|
||||
class CustomObject:
|
||||
pass
|
||||
|
||||
query = type('Query', (), {
|
||||
'variables': {
|
||||
'string_value': 'hello',
|
||||
'int_value': 42,
|
||||
'float_value': 3.14,
|
||||
'bool_value': True,
|
||||
'null_value': None,
|
||||
'list_value': ['a', 'b', 'c'],
|
||||
'dict_value': {'nested': 'value'},
|
||||
'custom_object': CustomObject(),
|
||||
},
|
||||
})()
|
||||
|
||||
params = QueryEntryAdapter.build_params(query)
|
||||
assert 'string_value' in params
|
||||
assert 'int_value' in params
|
||||
assert 'float_value' in params
|
||||
assert 'bool_value' in params
|
||||
assert 'null_value' in params
|
||||
assert 'list_value' in params
|
||||
assert 'dict_value' in params
|
||||
assert 'custom_object' not in params
|
||||
|
||||
def test_params_filters_nested_non_serializable(self):
|
||||
class CustomObject:
|
||||
pass
|
||||
|
||||
query = type('Query', (), {
|
||||
'variables': {
|
||||
'nested_list_with_bad': ['a', CustomObject(), 'c'],
|
||||
'nested_dict_with_bad': {'good': 'value', 'bad': CustomObject()},
|
||||
'good_nested_list': ['a', ['b', 'c']],
|
||||
'good_nested_dict': {'outer': {'inner': 'value'}},
|
||||
},
|
||||
})()
|
||||
|
||||
params = QueryEntryAdapter.build_params(query)
|
||||
assert 'nested_list_with_bad' not in params
|
||||
assert 'nested_dict_with_bad' not in params
|
||||
assert 'good_nested_list' in params
|
||||
assert 'good_nested_dict' in params
|
||||
|
||||
def test_is_json_serializable_primitives_and_collections(self):
|
||||
assert QueryEntryAdapter.is_json_serializable(None) is True
|
||||
assert QueryEntryAdapter.is_json_serializable('string') is True
|
||||
assert QueryEntryAdapter.is_json_serializable(42) is True
|
||||
assert QueryEntryAdapter.is_json_serializable(['a', 'b']) is True
|
||||
assert QueryEntryAdapter.is_json_serializable({'key': 'value'}) is True
|
||||
assert QueryEntryAdapter.is_json_serializable((1, 2, 3)) is True
|
||||
|
||||
def test_is_json_serializable_rejects_sets_and_objects(self):
|
||||
class CustomObject:
|
||||
pass
|
||||
|
||||
assert QueryEntryAdapter.is_json_serializable(CustomObject()) is False
|
||||
assert QueryEntryAdapter.is_json_serializable({1, 2, 3}) is False
|
||||
assert QueryEntryAdapter.is_json_serializable([1, {2, 3}]) is False
|
||||
assert QueryEntryAdapter.is_json_serializable({'key': {1, 2}}) is False
|
||||
|
||||
|
||||
class TestBuildAdapterContext:
|
||||
"""Tests for QueryEntryAdapter.build_adapter_context."""
|
||||
|
||||
def test_adapter_context_does_not_push_prompt(self):
|
||||
query = type('Query', (), {
|
||||
'variables': {},
|
||||
'query_id': 123,
|
||||
'prompt': object(),
|
||||
})()
|
||||
|
||||
context = QueryEntryAdapter.build_adapter_context(query, binding=None)
|
||||
|
||||
assert context == {'params': {}, 'query_id': 123}
|
||||
@@ -0,0 +1,361 @@
|
||||
"""Tests for ContextAccess.state determination in AgentRunContextBuilder.
|
||||
|
||||
Tests focus on:
|
||||
- Event-first mode: state=True when enable_state=True and state_scopes non-empty
|
||||
- Event-first mode: state=False when enable_state=False
|
||||
- Legacy Query mode: state=False (no persistent state API)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application for testing."""
|
||||
def __init__(self):
|
||||
self.logger = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock()
|
||||
|
||||
|
||||
def make_descriptor(
|
||||
permissions: dict | None = None,
|
||||
) -> AgentRunnerDescriptor:
|
||||
return AgentRunnerDescriptor(
|
||||
id='plugin:test/runner/default',
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='runner',
|
||||
runner_name='default',
|
||||
permissions=permissions
|
||||
if permissions is not None
|
||||
else {
|
||||
'history': ['page', 'search'],
|
||||
'events': ['get', 'page'],
|
||||
'storage': ['plugin'],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestContextAccessStateDetermination:
|
||||
"""Tests for ContextAccess.state field determination - real calls to _build_context_access."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
return MockApplication()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event(self):
|
||||
"""Create mock event envelope."""
|
||||
return AgentEventEnvelope(
|
||||
event_id='evt_001',
|
||||
event_type='message.received',
|
||||
event_time=1234567890,
|
||||
source='test',
|
||||
bot_id='bot_001',
|
||||
workspace_id='ws_001',
|
||||
conversation_id='conv_001',
|
||||
thread_id=None,
|
||||
actor=ActorContext(actor_type='user', actor_id='user_001'),
|
||||
subject=None,
|
||||
input=AgentInput(text='hello', contents=[], attachments=[]),
|
||||
delivery=DeliveryContext(surface='test', supports_streaming=True),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_descriptor(self):
|
||||
"""Create mock runner descriptor."""
|
||||
return make_descriptor()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=True when enable_state=True and state_scopes non-empty."""
|
||||
# Create binding with state enabled and non-empty scopes
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor'],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call to _build_context_access
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=True based on binding.state_policy
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_false_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when enable_state=False."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=False,
|
||||
state_scopes=[],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=False
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_true_empty_scopes_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when enable_state=True but state_scopes empty."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=[], # Empty scopes - state not available
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Verify state=False (empty scopes means state not available)
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_binding_sets_state_false(self, mock_app, mock_event, mock_descriptor):
|
||||
"""ContextAccess.state=False when no binding is provided."""
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call without binding
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding=None)
|
||||
|
||||
# Verify state=False (no binding = no state policy = state disabled)
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_scope_available_without_conversation(self, mock_app, mock_descriptor):
|
||||
"""State API with runner scope is available even without conversation_id."""
|
||||
mock_event = AgentEventEnvelope(
|
||||
event_id='evt_002',
|
||||
event_type='message.received',
|
||||
event_time=1234567890,
|
||||
source='test',
|
||||
bot_id='bot_001',
|
||||
workspace_id='ws_001',
|
||||
conversation_id=None, # No conversation
|
||||
thread_id=None,
|
||||
actor=ActorContext(actor_type='user', actor_id='user_001'),
|
||||
subject=None,
|
||||
input=AgentInput(text='hello', contents=[], attachments=[]),
|
||||
delivery=DeliveryContext(surface='test', supports_streaming=True),
|
||||
)
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_002',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='workspace', scope_id='ws_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['runner'], # Runner scope doesn't need conversation_id
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# State should be True because runner scope is enabled
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_scopes_all_available(self, mock_app, mock_event, mock_descriptor):
|
||||
"""State API with multiple scopes enabled."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_003',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor', 'subject', 'runner'],
|
||||
),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# State should be True with all scopes enabled
|
||||
assert context_access['available_apis']['state'] is True
|
||||
|
||||
|
||||
class TestStatePolicyFromBinding:
|
||||
"""Tests for state_policy extraction from binding."""
|
||||
|
||||
def test_state_policy_structure(self):
|
||||
"""State policy has correct structure."""
|
||||
policy = StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation', 'actor', 'subject', 'runner'],
|
||||
)
|
||||
|
||||
assert policy.enable_state is True
|
||||
assert len(policy.state_scopes) == 4
|
||||
assert 'conversation' in policy.state_scopes
|
||||
|
||||
def test_state_policy_disabled(self):
|
||||
"""State policy can be disabled."""
|
||||
policy = StatePolicy(
|
||||
enable_state=False,
|
||||
state_scopes=[],
|
||||
)
|
||||
|
||||
assert policy.enable_state is False
|
||||
assert len(policy.state_scopes) == 0
|
||||
|
||||
|
||||
class TestBindingWithStatePolicy:
|
||||
"""Tests for binding with state_policy."""
|
||||
|
||||
def test_binding_contains_state_policy(self):
|
||||
"""Binding contains state_policy field."""
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(
|
||||
enable_state=True,
|
||||
state_scopes=['conversation'],
|
||||
),
|
||||
)
|
||||
|
||||
assert binding.state_policy is not None
|
||||
assert binding.state_policy.enable_state is True
|
||||
|
||||
|
||||
class TestContextAccessOtherAPIs:
|
||||
"""Tests for other available_apis fields based on run scope."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
return MockApplication()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_apis_enabled_with_conversation(self, mock_app):
|
||||
"""History APIs are available when the run has a conversation scope."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
assert context_access['available_apis']['prompt_get'] is False
|
||||
assert context_access['available_apis']['history_page'] is True
|
||||
assert context_access['available_apis']['history_search'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_apis_enabled_by_default(self, mock_app):
|
||||
"""Event APIs are available based on current run scope."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
assert context_access['available_apis']['event_get'] is True
|
||||
assert context_access['available_apis']['event_page'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_required_apis_disabled_without_conversation(self, mock_app):
|
||||
"""Conversation-scoped APIs are disabled when the run has no conversation."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = None
|
||||
mock_event.thread_id = None
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
assert context_access['available_apis']['history_page'] is False
|
||||
assert context_access['available_apis']['history_search'] is False
|
||||
assert context_access['available_apis']['event_get'] is True
|
||||
assert context_access['available_apis']['event_page'] is False
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manifest_permissions_disable_context_apis(self, mock_app):
|
||||
"""Pull APIs are disabled when manifest permissions omit them."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
mock_descriptor = make_descriptor(permissions={})
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
assert context_access['available_apis']['history_page'] is False
|
||||
assert context_access['available_apis']['history_search'] is False
|
||||
assert context_access['available_apis']['event_get'] is False
|
||||
assert context_access['available_apis']['event_page'] is False
|
||||
assert context_access['available_apis']['storage'] is False
|
||||
@@ -0,0 +1,428 @@
|
||||
"""Test that LangBot context builder output validates against SDK AgentRunContext."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# SDK imports for validation
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context import AgentRunContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import AgentEventContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context_access import ContextAccess
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
|
||||
|
||||
# LangBot imports
|
||||
from langbot.pkg.agent.runner.context_builder import (
|
||||
AgentRunContextBuilder,
|
||||
AgentResources as BuilderResources,
|
||||
)
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class TestContextValidation:
|
||||
"""Test that context builder output validates against SDK AgentRunContext."""
|
||||
|
||||
def _make_mock_app(self):
|
||||
"""Create a mock application."""
|
||||
mock_app = MagicMock(spec=app.Application)
|
||||
mock_app.ver_mgr = MagicMock()
|
||||
mock_app.ver_mgr.get_current_version = MagicMock(return_value="1.0.0")
|
||||
mock_app.persistence_mgr = MagicMock()
|
||||
mock_app.persistence_mgr.get_db_engine = MagicMock()
|
||||
mock_app.logger = MagicMock()
|
||||
return mock_app
|
||||
|
||||
def _make_event_envelope(self) -> AgentEventEnvelope:
|
||||
"""Create a test event envelope."""
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput as EventInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
return AgentEventEnvelope(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
event_time=1700000000,
|
||||
source="platform",
|
||||
source_event_type="platform.message",
|
||||
bot_id="bot_1",
|
||||
workspace_id="workspace_1",
|
||||
conversation_id="conv_1",
|
||||
thread_id=None,
|
||||
actor=ActorContext(
|
||||
actor_type="user",
|
||||
actor_id="user_1",
|
||||
actor_name="Test User",
|
||||
),
|
||||
subject=None,
|
||||
input=EventInput(text="Hello world"),
|
||||
delivery=DeliveryContext(surface="test"),
|
||||
data={"platform_event_id": "source_evt_1"},
|
||||
)
|
||||
|
||||
def _make_binding(self) -> AgentBinding:
|
||||
"""Create a test binding."""
|
||||
return AgentBinding(
|
||||
binding_id="binding_1",
|
||||
scope=BindingScope(scope_type="agent", scope_id="pipeline_1"),
|
||||
event_types=["message.received"],
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
runner_config={"timeout": 300},
|
||||
agent_id="pipeline_1",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
def _make_resources(self) -> BuilderResources:
|
||||
"""Create test resources."""
|
||||
return {
|
||||
'models': [],
|
||||
'tools': [],
|
||||
'knowledge_bases': [],
|
||||
'skills': [],
|
||||
'files': [],
|
||||
'storage': {'plugin_storage': True, 'workspace_storage': True},
|
||||
'platform_capabilities': {},
|
||||
}
|
||||
|
||||
def _make_descriptor(self):
|
||||
"""Create a mock runner descriptor."""
|
||||
return AgentRunnerDescriptor(
|
||||
id="plugin:test/plugin/runner",
|
||||
source="plugin",
|
||||
label={"en_US": "Test Runner"},
|
||||
plugin_author="test",
|
||||
plugin_name="plugin",
|
||||
runner_name="runner",
|
||||
permissions={
|
||||
"history": ["page", "search"],
|
||||
"events": ["get", "page"],
|
||||
"storage": ["plugin", "workspace"],
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_validates(self):
|
||||
"""Test that build_context_from_event output validates against SDK AgentRunContext."""
|
||||
mock_app = self._make_mock_app()
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
# Build context
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# Validate it can be parsed by SDK AgentRunContext
|
||||
# This will raise ValidationError if invalid
|
||||
validated = AgentRunContext.model_validate(context_dict)
|
||||
|
||||
# Verify required fields
|
||||
assert validated.run_id is not None
|
||||
assert validated.event is not None
|
||||
assert isinstance(validated.event, AgentEventContext)
|
||||
assert validated.delivery is not None
|
||||
assert isinstance(validated.delivery, DeliveryContext)
|
||||
assert validated.context is not None
|
||||
assert isinstance(validated.context, ContextAccess)
|
||||
assert validated.input is not None
|
||||
assert isinstance(validated.input, AgentInput)
|
||||
assert validated.resources is not None
|
||||
assert isinstance(validated.resources, AgentResources)
|
||||
assert validated.runtime is not None
|
||||
assert isinstance(validated.runtime, AgentRuntimeContext)
|
||||
assert "protocol_version" not in validated.runtime.model_dump()
|
||||
assert "sdk_protocol_version" not in validated.runtime.model_dump()
|
||||
assert "sdk_protocol_version" not in context_dict["runtime"]
|
||||
|
||||
# Verify event context
|
||||
assert validated.event.event_id == "evt_1"
|
||||
assert validated.event.event_type == "message.received"
|
||||
assert validated.event.source == "platform"
|
||||
assert validated.event.source_event_type == "platform.message"
|
||||
assert validated.event.data == {"platform_event_id": "source_evt_1"}
|
||||
|
||||
# Verify conversation context uses SDK field names
|
||||
assert validated.conversation is not None
|
||||
assert validated.conversation.bot_id == "bot_1"
|
||||
assert validated.conversation.workspace_id == "workspace_1"
|
||||
|
||||
# Verify delivery context
|
||||
assert validated.delivery.surface == "test"
|
||||
|
||||
# Verify input
|
||||
assert validated.input.text == "Hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_populates_model_context_window(self):
|
||||
"""Runtime metadata should expose the selected LLM model context window."""
|
||||
mock_app = self._make_mock_app()
|
||||
mock_app.model_mgr = MagicMock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
model_entity=SimpleNamespace(context_length=128000),
|
||||
)
|
||||
)
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
resources['models'] = [
|
||||
{
|
||||
'model_id': 'rerank-model',
|
||||
'model_type': 'rerank',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['rerank'],
|
||||
},
|
||||
{
|
||||
'model_id': 'llm-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
]
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
assert context_dict['runtime']['metadata']['model_context_window_tokens'] == 128000
|
||||
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('llm-model')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_context_window_uses_primary_llm_only(self):
|
||||
"""Fallback model windows should not replace missing primary model metadata."""
|
||||
mock_app = self._make_mock_app()
|
||||
mock_app.model_mgr = MagicMock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
model_entity=SimpleNamespace(context_length=None),
|
||||
)
|
||||
)
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
resources = self._make_resources()
|
||||
resources['models'] = [
|
||||
{
|
||||
'model_id': 'primary-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
{
|
||||
'model_id': 'fallback-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
]
|
||||
|
||||
assert await builder._build_model_context_window_tokens(resources) is None
|
||||
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('primary-model')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_preserves_subject_data_for_non_message_events(self):
|
||||
"""Non-message EBA events keep subject.data instead of relying on message text."""
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext, SubjectContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput as EventInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
mock_app = self._make_mock_app()
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
event = AgentEventEnvelope(
|
||||
event_id="evt_recall_1",
|
||||
event_type="message.recalled",
|
||||
event_time=1700000001,
|
||||
source="platform",
|
||||
source_event_type="platform.message.recall",
|
||||
bot_id="bot_1",
|
||||
workspace_id="workspace_1",
|
||||
conversation_id="conv_1",
|
||||
actor=ActorContext(actor_type="user", actor_id="user_1"),
|
||||
subject=SubjectContext(
|
||||
subject_type="message",
|
||||
subject_id="message_1",
|
||||
data={"recalled_message_id": "message_1", "reason": "user_recall"},
|
||||
),
|
||||
input=EventInput(text=None),
|
||||
delivery=DeliveryContext(surface="test"),
|
||||
data={"source_event_id": "source_recall_1"},
|
||||
)
|
||||
binding = self._make_binding()
|
||||
binding.event_types = ["message.recalled"]
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
validated = AgentRunContext.model_validate(context_dict)
|
||||
|
||||
assert validated.event.event_type == "message.recalled"
|
||||
assert validated.input.text is None
|
||||
assert validated.subject is not None
|
||||
assert validated.subject.subject_type == "message"
|
||||
assert validated.subject.subject_id == "message_1"
|
||||
assert validated.subject.data == {"recalled_message_id": "message_1", "reason": "user_recall"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_has_no_legacy_top_level_fields(self):
|
||||
"""Test that build_context_from_event does NOT have top-level messages/prompt/params."""
|
||||
mock_app = self._make_mock_app()
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# Protocol v1 does NOT have these as core fields
|
||||
assert 'messages' not in context_dict, "messages should not be top-level in Protocol v1"
|
||||
assert 'prompt' not in context_dict, "prompt should not be top-level in Protocol v1"
|
||||
assert 'params' not in context_dict, "params should not be top-level in Protocol v1"
|
||||
|
||||
# Protocol v1 DOES have these
|
||||
assert 'delivery' in context_dict, "delivery is REQUIRED in Protocol v1"
|
||||
assert 'context' in context_dict, "context (ContextAccess) is REQUIRED in Protocol v1"
|
||||
assert 'bootstrap' not in context_dict, "Host must not inline bootstrap/history windows"
|
||||
assert 'adapter' in context_dict, "adapter should exist"
|
||||
assert 'metadata' in context_dict, "metadata should exist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_event_is_not_none(self):
|
||||
"""Test that event field is NOT None in Protocol v1."""
|
||||
mock_app = self._make_mock_app()
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# event is REQUIRED in Protocol v1
|
||||
assert context_dict.get('event') is not None, "event is REQUIRED for Protocol v1"
|
||||
|
||||
# Validate
|
||||
validated = AgentRunContext.model_validate(context_dict)
|
||||
assert validated.event is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_delivery_is_not_none(self):
|
||||
"""Test that delivery field is NOT None in Protocol v1."""
|
||||
mock_app = self._make_mock_app()
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
# Mock persistent state store to return empty state snapshot
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
# delivery is REQUIRED in Protocol v1
|
||||
assert context_dict.get('delivery') is not None, "delivery is REQUIRED for Protocol v1"
|
||||
|
||||
# Validate
|
||||
validated = AgentRunContext.model_validate(context_dict)
|
||||
assert validated.delivery is not None
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Tests for event-first Protocol v1 entities and Query entry adapter.
|
||||
|
||||
Tests cover:
|
||||
1. Query -> AgentEventEnvelope conversion
|
||||
2. Current config -> AgentConfig projection and single-binding resolution
|
||||
3. AgentRunContext not inlining full history by default
|
||||
4. LangBot Host not defining context-window controls
|
||||
5. Event-first run() entry point
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Import SDK entities
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import (
|
||||
AgentEventContext,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.trigger import AgentTrigger
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context import AgentRunContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.result import (
|
||||
AgentRunResult,
|
||||
)
|
||||
|
||||
# Import LangBot host models
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
from langbot.pkg.agent.runner.binding_resolver import (
|
||||
AgentBindingResolver,
|
||||
AgentBindingResolutionError,
|
||||
)
|
||||
|
||||
|
||||
class TestQueryToEventEnvelope:
|
||||
"""Test Query -> AgentEventEnvelope conversion."""
|
||||
|
||||
def test_query_to_event_basic_fields(self, mock_query):
|
||||
"""Test basic field conversion from Query to Event envelope."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.event_type == "message.received"
|
||||
assert event.source == "host_adapter"
|
||||
assert event.bot_id == mock_query.bot_uuid
|
||||
assert event.actor is not None
|
||||
assert event.actor.actor_type == "user"
|
||||
|
||||
def test_query_to_event_input(self, mock_query):
|
||||
"""Test input conversion from Query."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.input is not None
|
||||
assert event.input.text == "Hello world"
|
||||
assert "message_chain" not in event.input.model_dump()
|
||||
|
||||
def test_query_to_event_conversation(self, mock_query):
|
||||
"""Test conversation context extraction."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.conversation_id == "conv-uuid-123"
|
||||
|
||||
def test_query_to_event_prefers_variable_conversation_id_when_conversation_uuid_missing(self, mock_query):
|
||||
"""Pipeline variables can provide the conversation identity for state scope."""
|
||||
mock_query.session.using_conversation.uuid = None
|
||||
mock_query.variables["conversation_id"] = "conv-from-vars"
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.conversation_id == "conv-from-vars"
|
||||
|
||||
def test_query_to_event_falls_back_to_launcher_session_for_state_scope(self, mock_query):
|
||||
"""Debug Chat and legacy pipeline runs may not have a conversation UUID."""
|
||||
mock_query.session.using_conversation.uuid = None
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.conversation_id == "person_launcher-123"
|
||||
|
||||
def test_query_to_event_delivery_context(self, mock_query):
|
||||
"""Test delivery context extraction."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.delivery is not None
|
||||
assert event.delivery.surface == "platform"
|
||||
assert isinstance(event.delivery.supports_streaming, bool)
|
||||
|
||||
def test_query_to_event_preserves_source_event_data(self, mock_query):
|
||||
"""Test source event metadata survives the adapter boundary."""
|
||||
source_event = Mock()
|
||||
source_event.type = "platform.message.created"
|
||||
source_event.time = 1700000000
|
||||
source_event.sender = None
|
||||
source_event.model_dump = Mock(return_value={
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
"source_platform_object": {"large": "payload"},
|
||||
})
|
||||
mock_query.message_event = source_event
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.source_event_type == "platform.message.created"
|
||||
assert event.event_time == 1700000000
|
||||
assert event.data == {
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
}
|
||||
|
||||
def test_query_to_event_keeps_large_payloads_out_of_event_data(self, mock_query):
|
||||
"""Large or nested platform payloads should not be duplicated into event.data."""
|
||||
source_event = Mock()
|
||||
source_event.type = "platform.message.created"
|
||||
source_event.time = 1700000000
|
||||
source_event.sender = None
|
||||
source_event.model_dump = Mock(return_value={
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
"message_chain": [{"type": "Image", "base64": "data:image/png;base64," + ("a" * 1024)}],
|
||||
"raw_text": "x" * 1024,
|
||||
"source_platform_object": {"large": "payload"},
|
||||
})
|
||||
mock_query.message_event = source_event
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.data == {
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
}
|
||||
|
||||
def test_query_to_event_handles_missing_message_chain(self, mock_query):
|
||||
"""Test delivery context building when Query has no message_chain."""
|
||||
delattr(mock_query, "message_chain")
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.delivery.reply_target == {"message_id": None}
|
||||
|
||||
def test_query_to_event_scopes_pipeline_local_event_ids(self, mock_query):
|
||||
"""Pipeline-local message IDs must not become global audit IDs."""
|
||||
first = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
mock_query.launcher_id = "launcher-456"
|
||||
second = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert first.event_id.startswith("host:")
|
||||
assert first.event_id != "789"
|
||||
assert second.event_id != first.event_id
|
||||
|
||||
|
||||
class TestQueryConfigToAgentConfig:
|
||||
"""Test current config projection and single-Agent binding resolution."""
|
||||
|
||||
def test_config_to_agent_config_runner_id(self, mock_query):
|
||||
"""Test AgentConfig runner_id extraction."""
|
||||
agent_config = QueryEntryAdapter.config_to_agent_config(
|
||||
mock_query, "plugin:author/plugin/runner"
|
||||
)
|
||||
|
||||
assert agent_config.runner_id == "plugin:author/plugin/runner"
|
||||
|
||||
def test_config_to_agent_config_uses_legacy_runner_config_migration(self, mock_query):
|
||||
"""Temporary query adapter must share the normal runner config resolver."""
|
||||
mock_query.pipeline_config = {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": "model-primary",
|
||||
"knowledge-base": "kb-001",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
agent_config = QueryEntryAdapter.config_to_agent_config(
|
||||
mock_query,
|
||||
"plugin:langbot/local-agent/default",
|
||||
)
|
||||
|
||||
assert agent_config.runner_config["model"] == {
|
||||
"primary": "model-primary",
|
||||
"fallbacks": [],
|
||||
}
|
||||
assert agent_config.runner_config["knowledge-bases"] == ["kb-001"]
|
||||
|
||||
def test_resolver_projects_agent_scope(self, mock_query):
|
||||
"""Test binding scope projection through the resolver."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
agent_config = QueryEntryAdapter.config_to_agent_config(
|
||||
mock_query, "plugin:test/plugin/runner"
|
||||
)
|
||||
binding = AgentBindingResolver().resolve_one(event, [agent_config])
|
||||
|
||||
assert binding.scope.scope_type == "agent"
|
||||
assert binding.scope.scope_id == mock_query.pipeline_uuid
|
||||
assert binding.agent_id == mock_query.pipeline_uuid
|
||||
|
||||
def test_resolver_rejects_multiple_matching_agents(self, mock_query):
|
||||
"""Event dispatch is single-Agent in v1."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
first = QueryEntryAdapter.config_to_agent_config(
|
||||
mock_query, "plugin:test/plugin/runner"
|
||||
)
|
||||
second = first.model_copy(update={"agent_id": "agent_2"})
|
||||
|
||||
with pytest.raises(AgentBindingResolutionError):
|
||||
AgentBindingResolver().resolve_one(event, [first, second])
|
||||
|
||||
class TestAgentRunContextProtocolV1:
|
||||
"""Test AgentRunContext Protocol v1 behavior."""
|
||||
|
||||
def test_sdk_context_event_required(self):
|
||||
"""Test that event is required in Protocol v1 context."""
|
||||
trigger = AgentTrigger(type="message.received")
|
||||
event = AgentEventContext(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
)
|
||||
input = AgentInput(text="Hello")
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
ctx = AgentRunContext(
|
||||
run_id="run_1",
|
||||
trigger=trigger,
|
||||
event=event,
|
||||
input=input,
|
||||
delivery=DeliveryContext(surface="platform"),
|
||||
resources=AgentResources(),
|
||||
runtime=AgentRuntimeContext(),
|
||||
)
|
||||
|
||||
assert ctx.event is not None
|
||||
assert ctx.event.event_type == "message.received"
|
||||
|
||||
def test_sdk_context_has_no_history_message_fields(self):
|
||||
"""AgentRunContext should not expose inline history message fields."""
|
||||
trigger = AgentTrigger(type="message.received")
|
||||
event = AgentEventContext(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
)
|
||||
input = AgentInput(text="Hello")
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
ctx = AgentRunContext(
|
||||
run_id="run_1",
|
||||
trigger=trigger,
|
||||
event=event,
|
||||
input=input,
|
||||
delivery=DeliveryContext(surface="platform"),
|
||||
resources=AgentResources(),
|
||||
runtime=AgentRuntimeContext(),
|
||||
)
|
||||
|
||||
assert "messages" not in AgentRunContext.model_fields
|
||||
assert "bootstrap" not in AgentRunContext.model_fields
|
||||
assert not hasattr(ctx, "bootstrap")
|
||||
|
||||
|
||||
class TestHostManagedHistoryNotInProtocol:
|
||||
"""Test that Host-managed history payloads are not in Protocol v1."""
|
||||
|
||||
def test_messages_not_in_sdk_context_top_level(self):
|
||||
"""AgentRunContext should not expose top-level history messages."""
|
||||
ctx_fields = AgentRunContext.model_fields.keys()
|
||||
|
||||
assert "messages" not in ctx_fields
|
||||
|
||||
|
||||
class TestSDKResultProtocolV1:
|
||||
"""Test SDK AgentRunResult for Protocol v1."""
|
||||
|
||||
def test_result_requires_run_id(self):
|
||||
"""Test result requires run_id for Protocol v1."""
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
result = AgentRunResult.message_completed(
|
||||
run_id="run_1",
|
||||
message=Message(role="assistant", content="Hello"),
|
||||
)
|
||||
|
||||
assert result.run_id == "run_1"
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_query():
|
||||
"""Create a mock query for testing."""
|
||||
query = Mock()
|
||||
query.query_id = 123
|
||||
query.bot_uuid = "bot-uuid-123"
|
||||
query.pipeline_uuid = "pipeline-uuid-456"
|
||||
query.launcher_type = Mock(value="person")
|
||||
query.launcher_id = "launcher-123"
|
||||
query.sender_id = "sender-123"
|
||||
query.pipeline_config = {
|
||||
"ai": {
|
||||
"runner": "plugin:test/plugin/runner",
|
||||
}
|
||||
}
|
||||
query.variables = {}
|
||||
|
||||
# Create a proper content element mock
|
||||
content_elem = Mock(spec=['type', 'text', 'model_dump'])
|
||||
content_elem.type = 'text'
|
||||
content_elem.text = 'Hello world'
|
||||
content_elem.model_dump = Mock(return_value={'type': 'text', 'text': 'Hello world'})
|
||||
|
||||
query.user_message = Mock()
|
||||
query.user_message.content = [content_elem]
|
||||
|
||||
# Create message_chain mock
|
||||
message_chain = Mock()
|
||||
message_chain.message_id = 789
|
||||
message_chain.model_dump = Mock(return_value={'message_id': 789, 'components': []})
|
||||
query.message_chain = message_chain
|
||||
|
||||
query.message_event = None
|
||||
|
||||
# Mock session with proper conversation
|
||||
query.session = Mock()
|
||||
query.session.launcher_type = Mock(value="person")
|
||||
query.session.launcher_id = "launcher-123"
|
||||
query.session.using_conversation = Mock()
|
||||
query.session.using_conversation.uuid = "conv-uuid-123"
|
||||
|
||||
# Mock use_funcs (empty list by default)
|
||||
query.use_funcs = []
|
||||
query.use_llm_model_uuid = None
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_no_session():
|
||||
"""Create a mock Query without session."""
|
||||
query = Mock()
|
||||
query.query_id = 456
|
||||
query.bot_uuid = "bot-uuid-456"
|
||||
query.pipeline_uuid = "pipeline-uuid-789"
|
||||
query.launcher_type = Mock(value="person")
|
||||
query.launcher_id = "launcher-456"
|
||||
query.sender_id = "sender-456"
|
||||
query.pipeline_config = {
|
||||
"ai": {
|
||||
"runner": "plugin:test/plugin/runner",
|
||||
}
|
||||
}
|
||||
query.variables = {}
|
||||
|
||||
# Create a proper content element mock
|
||||
content_elem = Mock(spec=['type', 'text', 'model_dump'])
|
||||
content_elem.type = 'text'
|
||||
content_elem.text = 'Test message'
|
||||
content_elem.model_dump = Mock(return_value={'type': 'text', 'text': 'Test message'})
|
||||
|
||||
query.user_message = Mock()
|
||||
query.user_message.content = [content_elem]
|
||||
|
||||
message_chain = Mock()
|
||||
message_chain.message_id = -1
|
||||
message_chain.model_dump = Mock(return_value={'message_id': -1, 'components': []})
|
||||
query.message_chain = message_chain
|
||||
|
||||
query.message_event = None
|
||||
query.session = None
|
||||
|
||||
# Mock use_funcs
|
||||
query.use_funcs = []
|
||||
query.use_llm_model_uuid = None
|
||||
|
||||
return query
|
||||
@@ -0,0 +1,795 @@
|
||||
"""Tests for EventLog, Transcript, and history/event APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.host_models import (
|
||||
AgentEventEnvelope,
|
||||
AgentBinding,
|
||||
BindingScope,
|
||||
ResourcePolicy,
|
||||
StatePolicy,
|
||||
DeliveryPolicy,
|
||||
)
|
||||
from langbot.pkg.agent.runner.event_log_store import EventLogStore
|
||||
from langbot.pkg.agent.runner.transcript_store import TranscriptStore
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import (
|
||||
ActorContext,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
|
||||
def make_event_envelope(
|
||||
event_id: str = "evt_1",
|
||||
event_type: str = "message.received",
|
||||
conversation_id: str | None = "conv_1",
|
||||
actor_id: str | None = "user_1",
|
||||
input_text: str = "Hello",
|
||||
) -> AgentEventEnvelope:
|
||||
"""Create a test event envelope."""
|
||||
return AgentEventEnvelope(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
event_time=1700000000,
|
||||
source="platform",
|
||||
bot_id="bot_1",
|
||||
workspace_id=None,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=None,
|
||||
actor=ActorContext(
|
||||
actor_type="user",
|
||||
actor_id=actor_id,
|
||||
actor_name="Test User",
|
||||
) if actor_id else None,
|
||||
subject=None,
|
||||
input=AgentInput(text=input_text),
|
||||
delivery=DeliveryContext(surface="test"),
|
||||
)
|
||||
|
||||
|
||||
def make_binding(runner_id: str = "plugin:test/plugin/runner") -> AgentBinding:
|
||||
"""Create a test binding."""
|
||||
return AgentBinding(
|
||||
binding_id="binding_1",
|
||||
scope=BindingScope(scope_type="agent", scope_id="pipeline_1"),
|
||||
event_types=["message.received"],
|
||||
runner_id=runner_id,
|
||||
runner_config={},
|
||||
resource_policy=ResourcePolicy(),
|
||||
state_policy=StatePolicy(),
|
||||
delivery_policy=DeliveryPolicy(),
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
class TestEventLogStore:
|
||||
"""Test EventLogStore operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event(self, mock_db_engine):
|
||||
"""Test appending an event to EventLog."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
bot_id="bot_1",
|
||||
conversation_id="conv_1",
|
||||
actor_type="user",
|
||||
actor_id="user_1",
|
||||
input_summary="Hello world",
|
||||
run_id="run_1",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
)
|
||||
|
||||
assert event_id == "evt_1"
|
||||
stored_event = mock_session.add.call_args.args[0]
|
||||
assert stored_event.metadata_json is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event_stores_metadata_json(self, mock_db_engine):
|
||||
"""EventLog metadata records steering dispatch/audit facts."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_steering",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
run_id="run_1",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
metadata={
|
||||
"steering": {
|
||||
"status": "queued",
|
||||
"claimed_by_run_id": "run_1",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert event_id == "evt_steering"
|
||||
stored_event = mock_session.add.call_args.args[0]
|
||||
assert '"status": "queued"' in stored_event.metadata_json
|
||||
assert '"claimed_by_run_id": "run_1"' in stored_event.metadata_json
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event_truncates_input_summary(self, mock_db_engine):
|
||||
"""Test that long input summaries are truncated."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
long_text = "x" * 2000
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_2",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
input_summary=long_text,
|
||||
)
|
||||
|
||||
assert event_id == "evt_2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_events_with_conversation_filter(self, mock_db_engine):
|
||||
"""Test paging events with conversation_id filter."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id="conv_1",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert isinstance(items, list)
|
||||
|
||||
|
||||
class TestTranscriptStore:
|
||||
"""Test TranscriptStore operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_transcript(self, mock_db_engine):
|
||||
"""Test appending a transcript item."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Mock _get_next_seq
|
||||
with patch.object(store, '_get_next_seq', return_value=1):
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
transcript_id = await store.append_transcript(
|
||||
transcript_id=None, # Auto-generate
|
||||
event_id="evt_1",
|
||||
conversation_id="conv_1",
|
||||
role="user",
|
||||
content="Hello",
|
||||
)
|
||||
|
||||
assert transcript_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_transcript_with_attachments(self, mock_db_engine):
|
||||
"""Test appending transcript with attachment refs."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_get_next_seq', return_value=1):
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
transcript_id = await store.append_transcript(
|
||||
transcript_id=None, # Auto-generate
|
||||
event_id="evt_2",
|
||||
conversation_id="conv_1",
|
||||
role="assistant",
|
||||
content="Here's an image",
|
||||
attachment_refs=[
|
||||
{"id": "att_1", "type": "image", "url": "http://example.com/img.png"}
|
||||
],
|
||||
)
|
||||
|
||||
assert transcript_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_transcript_backward(self, mock_db_engine):
|
||||
"""Test paging transcript backward (older items)."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_1",
|
||||
limit=10,
|
||||
direction="backward",
|
||||
)
|
||||
|
||||
assert isinstance(items, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_transcript_has_hard_limit(self, mock_db_engine):
|
||||
"""Test that transcript paging has a hard limit."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
# Request more than the hard limit
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_1",
|
||||
limit=200, # Request 200, but hard limit is 100
|
||||
)
|
||||
|
||||
# The store should cap at 100
|
||||
assert len(items) <= store.HARD_LIMIT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_transcript(self, mock_db_engine):
|
||||
"""Test searching transcript."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items = await store.search_transcript(
|
||||
conversation_id="conv_1",
|
||||
query_text="database",
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
assert isinstance(items, list)
|
||||
|
||||
|
||||
class TestHistoryPageAuthorization:
|
||||
"""Test history.page authorization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_requires_run_id(self, mock_handler, mock_db_engine):
|
||||
"""Test history.page requires run_id."""
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
# Mock call_action to simulate the handler
|
||||
result = await mock_handler.call_action(
|
||||
PluginToRuntimeAction.HISTORY_PAGE,
|
||||
{"run_id": None},
|
||||
)
|
||||
|
||||
# Should return error
|
||||
assert result.get("ok") is False or "error" in str(result).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_validates_conversation_scope(self, mock_db_engine):
|
||||
"""Test history.page only allows access to run's conversation."""
|
||||
# This test verifies the authorization logic
|
||||
# The actual implementation validates conversation_id matches session
|
||||
session_registry = get_session_registry()
|
||||
|
||||
await session_registry.register(
|
||||
run_id="run_1",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
query_id=None,
|
||||
plugin_identity="test/plugin",
|
||||
resources={"models": [], "tools": [], "knowledge_bases": [], "storage": {"plugin_storage": True}},
|
||||
conversation_id="conv_1",
|
||||
)
|
||||
|
||||
session = await session_registry.get("run_1")
|
||||
assert session is not None
|
||||
assert session["authorization"]["conversation_id"] == "conv_1"
|
||||
|
||||
# Cleanup
|
||||
await session_registry.unregister("run_1")
|
||||
|
||||
|
||||
class TestEventGetAuthorization:
|
||||
"""Test event.get authorization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_get_requires_run_id(self, mock_handler):
|
||||
"""Test event.get requires run_id."""
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
result = await mock_handler.call_action(
|
||||
PluginToRuntimeAction.EVENT_GET,
|
||||
{"run_id": None, "event_id": "evt_1"},
|
||||
)
|
||||
|
||||
# Should return error
|
||||
assert result.get("ok") is False or "error" in str(result).lower()
|
||||
|
||||
|
||||
class TestContextAccessPopulation:
|
||||
"""Test ContextAccess population in build_context_from_event."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
|
||||
"""Test ContextAccess shows available APIs based on permissions."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
cursor = await store.get_latest_cursor("conv_1")
|
||||
# Should return None or a cursor string
|
||||
assert cursor is None or isinstance(cursor, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_shows_has_history_before(self, mock_db_engine):
|
||||
"""Test ContextAccess indicates if history exists."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar.return_value = 0
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
has_history = await store.has_history_before("conv_1", 10)
|
||||
assert isinstance(has_history, bool)
|
||||
|
||||
|
||||
class TestEventLogStoreRealSQLite:
|
||||
"""Test EventLogStore with real SQLite database."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_get_event_round_trip(self, db_engine):
|
||||
"""Test append_event -> get_event round trip with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append event
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_real_001",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
bot_id="bot_001",
|
||||
conversation_id="conv_001",
|
||||
actor_type="user",
|
||||
actor_id="user_001",
|
||||
actor_name="Test User",
|
||||
input_summary="Hello world",
|
||||
run_id="run_001",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
)
|
||||
|
||||
assert event_id == "evt_real_001"
|
||||
|
||||
# Get event
|
||||
event = await store.get_event(event_id)
|
||||
assert event is not None
|
||||
assert event["event_id"] == "evt_real_001"
|
||||
assert event["event_type"] == "message.received"
|
||||
assert event["source"] == "platform"
|
||||
assert event["conversation_id"] == "conv_001"
|
||||
assert event["actor_type"] == "user"
|
||||
assert event["actor_id"] == "user_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_events(self, db_engine):
|
||||
"""Test page_events with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append multiple events
|
||||
for i in range(5):
|
||||
await store.append_event(
|
||||
event_id=f"evt_real_{i:03d}",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_001",
|
||||
input_summary=f"Message {i}",
|
||||
)
|
||||
|
||||
# Page events
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id="conv_001",
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert len(items) == 3
|
||||
assert has_more is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_cursor(self, db_engine):
|
||||
"""Test get_latest_cursor with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append events
|
||||
for i in range(3):
|
||||
await store.append_event(
|
||||
event_id=f"evt_cursor_{i:03d}",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_cursor",
|
||||
)
|
||||
|
||||
# Get latest cursor
|
||||
cursor = await store.get_latest_cursor("conv_cursor")
|
||||
assert cursor is not None
|
||||
assert int(cursor) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_events_older_than(self, db_engine):
|
||||
"""EventLog cleanup removes only rows older than the cutoff."""
|
||||
import sqlalchemy
|
||||
from langbot.pkg.entity.persistence.event_log import EventLog
|
||||
|
||||
store = EventLogStore(db_engine)
|
||||
cutoff = datetime.datetime.utcnow()
|
||||
await store.append_event(
|
||||
event_id="evt_cleanup_old",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_cleanup",
|
||||
)
|
||||
await store.append_event(
|
||||
event_id="evt_cleanup_new",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_cleanup",
|
||||
)
|
||||
async with store._session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(EventLog)
|
||||
.where(EventLog.event_id == "evt_cleanup_old")
|
||||
.values(created_at=cutoff - datetime.timedelta(days=2))
|
||||
)
|
||||
await session.execute(
|
||||
sqlalchemy.update(EventLog)
|
||||
.where(EventLog.event_id == "evt_cleanup_new")
|
||||
.values(created_at=cutoff + datetime.timedelta(days=2))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
removed = await store.cleanup_events_older_than(cutoff)
|
||||
|
||||
assert removed == 1
|
||||
assert await store.get_event("evt_cleanup_old") is None
|
||||
assert await store.get_event("evt_cleanup_new") is not None
|
||||
|
||||
|
||||
class TestTranscriptStoreRealSQLite:
|
||||
"""Test TranscriptStore with real SQLite database."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_page_transcript_round_trip(self, db_engine):
|
||||
"""Test append_transcript -> page_transcript round trip with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
for i in range(3):
|
||||
await store.append_transcript(
|
||||
transcript_id=f"trans_real_{i:03d}",
|
||||
event_id=f"evt_{i:03d}",
|
||||
conversation_id="conv_001",
|
||||
role="user" if i % 2 == 0 else "assistant",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
# Page transcript
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_001",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert len(items) == 3
|
||||
assert items[0]["conversation_id"] == "conv_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_legacy_provider_messages_projects_transcript_history(self, db_engine):
|
||||
"""Transcript is the canonical source; legacy Pipeline readers get a Message view."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_view_001",
|
||||
event_id="evt_view_001",
|
||||
conversation_id="conv_view",
|
||||
role="user",
|
||||
content="User text",
|
||||
content_json={
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "User structured text"}],
|
||||
},
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_view_002",
|
||||
event_id="evt_view_002",
|
||||
conversation_id="conv_view",
|
||||
role="tool",
|
||||
item_type="tool_result",
|
||||
content="ignored tool result",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_view_003",
|
||||
event_id="evt_view_003",
|
||||
conversation_id="conv_view",
|
||||
role="assistant",
|
||||
content="Assistant text",
|
||||
)
|
||||
|
||||
messages = await store.get_legacy_provider_messages("conv_view")
|
||||
|
||||
assert [message.role for message in messages] == ["user", "assistant"]
|
||||
assert messages[0].content[0].text == "User structured text"
|
||||
assert messages[1].content == "Assistant text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_legacy_provider_messages_filters_scope(self, db_engine):
|
||||
"""Legacy Pipeline history projection must stay inside the current run scope."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_001",
|
||||
event_id="evt_scope_001",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
role="user",
|
||||
content="Current scope text",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_002",
|
||||
event_id="evt_scope_002",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_002",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
role="assistant",
|
||||
content="Other bot text",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_003",
|
||||
event_id="evt_scope_003",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_002",
|
||||
role="assistant",
|
||||
content="Other thread text",
|
||||
)
|
||||
|
||||
messages = await store.get_legacy_provider_messages(
|
||||
"conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
strict_thread=True,
|
||||
)
|
||||
|
||||
assert [message.content for message in messages] == ["Current scope text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_transcript_real_db(self, db_engine):
|
||||
"""Test search_transcript with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_search_001",
|
||||
event_id="evt_search_001",
|
||||
conversation_id="conv_search",
|
||||
role="user",
|
||||
content="I want to learn about databases",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_search_002",
|
||||
event_id="evt_search_002",
|
||||
conversation_id="conv_search",
|
||||
role="assistant",
|
||||
content="Here is information about databases",
|
||||
)
|
||||
|
||||
# Search for "database"
|
||||
items = await store.search_transcript(
|
||||
conversation_id="conv_search",
|
||||
query_text="database",
|
||||
)
|
||||
|
||||
# Should find at least one match
|
||||
assert len(items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_cursor_real_db(self, db_engine):
|
||||
"""Test get_latest_cursor with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
for i in range(3):
|
||||
await store.append_transcript(
|
||||
transcript_id=f"trans_cursor_{i:03d}",
|
||||
event_id=f"evt_cursor_{i:03d}",
|
||||
conversation_id="conv_cursor",
|
||||
role="user",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
# Get latest cursor
|
||||
cursor = await store.get_latest_cursor("conv_cursor")
|
||||
assert cursor is not None
|
||||
assert int(cursor) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_transcripts_older_than(self, db_engine):
|
||||
"""Transcript cleanup removes only rows older than the cutoff."""
|
||||
import sqlalchemy
|
||||
from langbot.pkg.entity.persistence.transcript import Transcript
|
||||
|
||||
store = TranscriptStore(db_engine)
|
||||
cutoff = datetime.datetime.utcnow()
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_cleanup_old",
|
||||
event_id="evt_cleanup_old",
|
||||
conversation_id="conv_cleanup",
|
||||
role="user",
|
||||
content="old",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_cleanup_new",
|
||||
event_id="evt_cleanup_new",
|
||||
conversation_id="conv_cleanup",
|
||||
role="assistant",
|
||||
content="new",
|
||||
)
|
||||
async with store._session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(Transcript)
|
||||
.where(Transcript.transcript_id == "trans_cleanup_old")
|
||||
.values(created_at=cutoff - datetime.timedelta(days=2))
|
||||
)
|
||||
await session.execute(
|
||||
sqlalchemy.update(Transcript)
|
||||
.where(Transcript.transcript_id == "trans_cleanup_new")
|
||||
.values(created_at=cutoff + datetime.timedelta(days=2))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
removed = await store.cleanup_transcripts_older_than(cutoff)
|
||||
items, _, _, _ = await store.page_transcript("conv_cleanup", limit=10)
|
||||
|
||||
assert removed == 1
|
||||
assert [item["content"] for item in items] == ["new"]
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_engine():
|
||||
"""Create a mock database engine for AsyncSession-based stores."""
|
||||
from unittest.mock import MagicMock
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine = MagicMock(spec=AsyncEngine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_handler():
|
||||
"""Create a mock handler for testing actions."""
|
||||
from langbot_plugin.runtime.io.handler import Handler
|
||||
|
||||
class MockHandler(Handler):
|
||||
def __init__(self):
|
||||
self._responses = {}
|
||||
|
||||
async def call_action(self, action, data, timeout=30):
|
||||
# Simulate error response for missing run_id
|
||||
if not data.get("run_id"):
|
||||
return {"ok": False, "message": "run_id is required"}
|
||||
return {"ok": True, "data": {}}
|
||||
|
||||
return MockHandler()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,323 @@
|
||||
"""Tests for AgentRunner history/event pull API authorization."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.event_log_store import EventLogStore
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.entity.persistence import event_log as event_log_model
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.page_results import (
|
||||
AgentEventRecord,
|
||||
EventPage,
|
||||
)
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
pass
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
def __init__(self, db_engine):
|
||||
self.logger = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry(monkeypatch):
|
||||
registry = AgentRunSessionRegistry()
|
||||
monkeypatch.setattr(
|
||||
'langbot.pkg.agent.runner.session_registry._global_registry',
|
||||
registry,
|
||||
)
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||
assert event_log_model.EventLog.__tablename__ == 'event_log'
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def _handler(db_engine, session_registry):
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
fake_app = FakeApplication(db_engine)
|
||||
return RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
|
||||
async def _register_session(
|
||||
session_registry,
|
||||
*,
|
||||
run_id='run_1',
|
||||
conversation_id='conv_1',
|
||||
bot_id=None,
|
||||
workspace_id=None,
|
||||
thread_id=None,
|
||||
available_apis=None,
|
||||
):
|
||||
await session_registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=None,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
conversation_id=conversation_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
available_apis=available_apis or {},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_requires_runtime_capability(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'history_page': False})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||
|
||||
result = await history_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not authorized' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_rejects_cross_conversation(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'history_page': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||
|
||||
result = await history_page({
|
||||
'run_id': 'run_1',
|
||||
'conversation_id': 'conv_other',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'history_search': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_search = handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value]
|
||||
|
||||
result = await history_search({
|
||||
'run_id': 'run_1',
|
||||
'query': 'hello',
|
||||
'filters': {'conversation_id': 'conv_other'},
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_requires_runtime_capability(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'event_page': False})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
result = await event_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not authorized' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_rejects_cross_conversation(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'event_page': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
result = await event_page({
|
||||
'run_id': 'run_1',
|
||||
'conversation_id': 'conv_other',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_get_returns_sdk_record_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'event_get': True})
|
||||
store = EventLogStore(db_engine)
|
||||
event_id = await store.append_event(
|
||||
event_id='evt_projection_1',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
conversation_id='conv_1',
|
||||
actor_type='user',
|
||||
actor_id='user_1',
|
||||
input_summary='hello',
|
||||
input_json={'internal': 'not part of AgentEventRecord'},
|
||||
run_id='run_1',
|
||||
runner_id='plugin:test/runner/default',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_get = handler.actions[PluginToRuntimeAction.EVENT_GET.value]
|
||||
|
||||
result = await event_get({
|
||||
'run_id': 'run_1',
|
||||
'event_id': event_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
AgentEventRecord.model_validate(result.data)
|
||||
assert 'id' not in result.data
|
||||
assert 'input_json' not in result.data
|
||||
assert 'run_id' not in result.data
|
||||
assert 'runner_id' not in result.data
|
||||
assert result.data['seq'] is not None
|
||||
assert result.data['cursor'] == str(result.data['seq'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_returns_sdk_page_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'event_page': True})
|
||||
store = EventLogStore(db_engine)
|
||||
await store.append_event(
|
||||
event_id='evt_projection_page_1',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
conversation_id='conv_1',
|
||||
input_json={'internal': 'not part of AgentEventRecord'},
|
||||
run_id='run_other',
|
||||
runner_id='plugin:test/runner/default',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
result = await event_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
page = EventPage.model_validate(result.data)
|
||||
assert len(page.items) == 1
|
||||
item = result.data['items'][0]
|
||||
assert 'id' not in item
|
||||
assert 'input_json' not in item
|
||||
assert 'run_id' not in item
|
||||
assert 'runner_id' not in item
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_filters_run_scope_thread_and_bot(session_registry, db_engine):
|
||||
from langbot.pkg.agent.runner.transcript_store import TranscriptStore
|
||||
|
||||
await _register_session(
|
||||
session_registry,
|
||||
bot_id='bot_1',
|
||||
thread_id='thread_1',
|
||||
available_apis={'history_page': True},
|
||||
)
|
||||
store = TranscriptStore(db_engine)
|
||||
await store.append_transcript(
|
||||
transcript_id='tr_visible',
|
||||
event_id='evt_visible',
|
||||
conversation_id='conv_1',
|
||||
role='user',
|
||||
bot_id='bot_1',
|
||||
thread_id='thread_1',
|
||||
content='visible',
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id='tr_other_bot',
|
||||
event_id='evt_other_bot',
|
||||
conversation_id='conv_1',
|
||||
role='user',
|
||||
bot_id='bot_2',
|
||||
thread_id='thread_1',
|
||||
content='hidden bot',
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id='tr_other_thread',
|
||||
event_id='evt_other_thread',
|
||||
conversation_id='conv_1',
|
||||
role='user',
|
||||
bot_id='bot_1',
|
||||
thread_id='thread_2',
|
||||
content='hidden thread',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||
|
||||
result = await history_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
assert [item['content'] for item in result.data['items']] == ['visible']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_filters_run_scope_thread_and_bot(session_registry, db_engine):
|
||||
await _register_session(
|
||||
session_registry,
|
||||
bot_id='bot_1',
|
||||
thread_id='thread_1',
|
||||
available_apis={'event_page': True},
|
||||
)
|
||||
store = EventLogStore(db_engine)
|
||||
await store.append_event(
|
||||
event_id='evt_visible',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
bot_id='bot_1',
|
||||
conversation_id='conv_1',
|
||||
thread_id='thread_1',
|
||||
)
|
||||
await store.append_event(
|
||||
event_id='evt_other_bot',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
bot_id='bot_2',
|
||||
conversation_id='conv_1',
|
||||
thread_id='thread_1',
|
||||
)
|
||||
await store.append_event(
|
||||
event_id='evt_other_thread',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
bot_id='bot_1',
|
||||
conversation_id='conv_1',
|
||||
thread_id='thread_2',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
result = await event_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
assert [item['event_id'] for item in result.data['items']] == ['evt_visible']
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Tests for agent runner ID parsing and formatting."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.id import (
|
||||
parse_runner_id,
|
||||
format_runner_id,
|
||||
RunnerIdParts,
|
||||
is_plugin_runner_id,
|
||||
)
|
||||
|
||||
|
||||
class TestRunnerIdParsing:
|
||||
"""Tests for parse_runner_id."""
|
||||
|
||||
def test_parse_plugin_runner_id(self):
|
||||
"""Parse valid plugin runner ID."""
|
||||
runner_id = 'plugin:langbot/local-agent/default'
|
||||
parts = parse_runner_id(runner_id)
|
||||
|
||||
assert parts.source == 'plugin'
|
||||
assert parts.plugin_author == 'langbot'
|
||||
assert parts.plugin_name == 'local-agent'
|
||||
assert parts.runner_name == 'default'
|
||||
|
||||
def test_parse_plugin_runner_id_complex_names(self):
|
||||
"""Parse plugin runner ID with complex names."""
|
||||
runner_id = 'plugin:alice/helpdesk-agent/ticket-handler'
|
||||
parts = parse_runner_id(runner_id)
|
||||
|
||||
assert parts.source == 'plugin'
|
||||
assert parts.plugin_author == 'alice'
|
||||
assert parts.plugin_name == 'helpdesk-agent'
|
||||
assert parts.runner_name == 'ticket-handler'
|
||||
|
||||
def test_parse_invalid_plugin_runner_id_missing_parts(self):
|
||||
"""Parse invalid plugin runner ID with missing parts."""
|
||||
runner_id = 'plugin:langbot/local-agent'
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parse_runner_id(runner_id)
|
||||
|
||||
assert 'Invalid plugin runner ID format' in str(exc_info.value)
|
||||
|
||||
def test_parse_invalid_plugin_runner_id_empty_parts(self):
|
||||
"""Parse invalid plugin runner ID with empty parts."""
|
||||
runner_id = 'plugin://default'
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parse_runner_id(runner_id)
|
||||
|
||||
assert 'non-empty' in str(exc_info.value)
|
||||
|
||||
def test_parse_invalid_runner_id_not_plugin(self):
|
||||
"""Parse invalid runner ID without plugin prefix."""
|
||||
runner_id = 'local-agent'
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parse_runner_id(runner_id)
|
||||
|
||||
assert 'Invalid runner ID format' in str(exc_info.value)
|
||||
|
||||
def test_parse_invalid_runner_id_empty_string(self):
|
||||
"""Parse empty runner ID."""
|
||||
runner_id = ''
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parse_runner_id(runner_id)
|
||||
|
||||
|
||||
class TestRunnerIdFormatting:
|
||||
"""Tests for format_runner_id."""
|
||||
|
||||
def test_format_plugin_runner_id(self):
|
||||
"""Format plugin runner ID."""
|
||||
runner_id = format_runner_id(
|
||||
source='plugin',
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
)
|
||||
|
||||
assert runner_id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
def test_format_invalid_source(self):
|
||||
"""Format runner ID with invalid source."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
format_runner_id(
|
||||
source='builtin',
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
)
|
||||
|
||||
assert 'Invalid runner source' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRunnerIdParts:
|
||||
"""Tests for RunnerIdParts dataclass."""
|
||||
|
||||
def test_get_plugin_id(self):
|
||||
"""Get plugin ID from parts."""
|
||||
parts = RunnerIdParts(
|
||||
source='plugin',
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
)
|
||||
|
||||
assert parts.to_plugin_id() == 'langbot/local-agent'
|
||||
|
||||
def test_frozen_dataclass(self):
|
||||
"""RunnerIdParts should be immutable."""
|
||||
parts = RunnerIdParts(
|
||||
source='plugin',
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception): # FrozenInstanceError
|
||||
parts.plugin_author = 'other'
|
||||
|
||||
|
||||
class TestIsPluginRunnerId:
|
||||
"""Tests for is_plugin_runner_id."""
|
||||
|
||||
def test_is_plugin_runner_id_true(self):
|
||||
"""Check plugin runner ID returns True."""
|
||||
assert is_plugin_runner_id('plugin:langbot/local-agent/default') is True
|
||||
|
||||
def test_is_plugin_runner_id_false(self):
|
||||
"""Check non-plugin runner ID returns False."""
|
||||
assert is_plugin_runner_id('local-agent') is False
|
||||
assert is_plugin_runner_id('builtin:local-agent') is False
|
||||
assert is_plugin_runner_id('') is False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,272 @@
|
||||
"""Tests for agent runner registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.registry import AgentRunnerRegistry
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.errors import RunnerNotFoundError, RunnerNotAuthorizedError
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
"""Fake Application for testing."""
|
||||
|
||||
def __init__(self):
|
||||
class FakeLogger:
|
||||
def info(self, msg):
|
||||
pass
|
||||
|
||||
def debug(self, msg):
|
||||
pass
|
||||
|
||||
def warning(self, msg):
|
||||
pass
|
||||
|
||||
def error(self, msg):
|
||||
pass
|
||||
|
||||
self.logger = FakeLogger()
|
||||
|
||||
class FakePluginConnector:
|
||||
is_enable_plugin = True
|
||||
|
||||
async def list_agent_runners(self, bound_plugins=None):
|
||||
# Return sample runner data
|
||||
return [
|
||||
{
|
||||
'plugin_author': 'langbot',
|
||||
'plugin_name': 'local-agent',
|
||||
'runner_name': 'default',
|
||||
'manifest': {
|
||||
'id': 'plugin:langbot/local-agent/default',
|
||||
'name': 'default',
|
||||
'label': {'en_US': 'Local Agent'},
|
||||
'capabilities': {'streaming': True},
|
||||
'permissions': {},
|
||||
'config_schema': [],
|
||||
},
|
||||
},
|
||||
{
|
||||
'plugin_author': 'alice',
|
||||
'plugin_name': 'my-agent',
|
||||
'runner_name': 'custom',
|
||||
'manifest': {
|
||||
'id': 'plugin:alice/my-agent/custom',
|
||||
'name': 'custom',
|
||||
'label': {'en_US': 'Custom Agent'},
|
||||
'capabilities': {},
|
||||
'permissions': {},
|
||||
'config_schema': [{'name': 'param1', 'type': 'string'}],
|
||||
},
|
||||
},
|
||||
# Invalid runner - wrong kind
|
||||
{
|
||||
'plugin_author': 'bad',
|
||||
'plugin_name': 'wrong-kind',
|
||||
'runner_name': 'default',
|
||||
'manifest': {
|
||||
'kind': 'Tool', # Wrong kind
|
||||
'metadata': {},
|
||||
'spec': {},
|
||||
},
|
||||
},
|
||||
# Invalid runner - missing name
|
||||
{
|
||||
'plugin_author': 'bad',
|
||||
'plugin_name': 'missing-name',
|
||||
'runner_name': 'default',
|
||||
'manifest': {
|
||||
'kind': 'AgentRunner',
|
||||
'metadata': {}, # No name
|
||||
'spec': {},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
self.plugin_connector = FakePluginConnector()
|
||||
|
||||
|
||||
class TestRegistryDiscovery:
|
||||
"""Tests for runner discovery."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_valid_runners(self):
|
||||
"""Discover valid runners from plugin runtime."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
runners = await registry.list_runners(use_cache=False)
|
||||
|
||||
# Should find 2 valid runners (langbot/local-agent and alice/my-agent)
|
||||
assert len(runners) == 2
|
||||
|
||||
ids = [r.id for r in runners]
|
||||
assert 'plugin:langbot/local-agent/default' in ids
|
||||
assert 'plugin:alice/my-agent/custom' in ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_caches_results(self):
|
||||
"""Discovery should cache results."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
# First discovery
|
||||
runners1 = await registry.list_runners(use_cache=True)
|
||||
|
||||
# Second call should use cache
|
||||
runners2 = await registry.list_runners(use_cache=True)
|
||||
|
||||
assert registry._cache is not None
|
||||
assert len(runners1) == len(runners2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_handles_plugin_disabled(self):
|
||||
"""Discovery returns empty when plugin system disabled."""
|
||||
ap = FakeApplication()
|
||||
ap.plugin_connector.is_enable_plugin = False
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
runners = await registry.list_runners(use_cache=False)
|
||||
|
||||
assert runners == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_not_polluted_by_bound_plugins(self):
|
||||
"""Cache should contain ALL runners, not filtered by bound_plugins.
|
||||
|
||||
Regression test: get(bound_plugins=["a/b"]) should not pollute cache,
|
||||
so subsequent list_runners(bound_plugins=None) should return all runners.
|
||||
"""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
# First: get with bound_plugins filter (should not pollute cache)
|
||||
descriptor = await registry.get(
|
||||
'plugin:langbot/local-agent/default',
|
||||
bound_plugins=['langbot/local-agent'],
|
||||
)
|
||||
assert descriptor.id == 'plugin:langbot/local-agent/default'
|
||||
|
||||
# Cache should contain ALL runners (both langbot and alice)
|
||||
assert registry._cache is not None
|
||||
assert len(registry._cache) == 2 # Both runners in cache
|
||||
assert 'plugin:langbot/local-agent/default' in registry._cache
|
||||
assert 'plugin:alice/my-agent/custom' in registry._cache
|
||||
|
||||
# Second: list_runners without filter should return ALL runners
|
||||
all_runners = await registry.list_runners(bound_plugins=None, use_cache=True)
|
||||
assert len(all_runners) == 2 # Both runners returned
|
||||
|
||||
# Third: list_runners with different filter should work correctly
|
||||
alice_runners = await registry.list_runners(bound_plugins=['alice/my-agent'], use_cache=True)
|
||||
assert len(alice_runners) == 1
|
||||
assert alice_runners[0].id == 'plugin:alice/my-agent/custom'
|
||||
|
||||
|
||||
class TestRegistryGet:
|
||||
"""Tests for getting specific runner."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_runner(self):
|
||||
"""Get existing runner by ID."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
descriptor = await registry.get('plugin:langbot/local-agent/default')
|
||||
|
||||
assert descriptor.id == 'plugin:langbot/local-agent/default'
|
||||
assert descriptor.plugin_author == 'langbot'
|
||||
assert descriptor.plugin_name == 'local-agent'
|
||||
assert descriptor.runner_name == 'default'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_runner(self):
|
||||
"""Get nonexistent runner raises RunnerNotFoundError."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
with pytest.raises(RunnerNotFoundError) as exc_info:
|
||||
await registry.get('plugin:notexist/unknown/default')
|
||||
|
||||
assert exc_info.value.runner_id == 'plugin:notexist/unknown/default'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runner_with_bound_plugins_filter(self):
|
||||
"""Get runner with bound plugins authorization."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
# Authorized - langbot plugin in bound list
|
||||
descriptor = await registry.get(
|
||||
'plugin:langbot/local-agent/default',
|
||||
bound_plugins=['langbot/local-agent'],
|
||||
)
|
||||
assert descriptor is not None
|
||||
|
||||
# Not authorized - plugin not in bound list
|
||||
with pytest.raises(RunnerNotAuthorizedError):
|
||||
await registry.get(
|
||||
'plugin:alice/my-agent/custom',
|
||||
bound_plugins=['langbot/local-agent'],
|
||||
)
|
||||
|
||||
|
||||
class TestRegistryMetadataForPipeline:
|
||||
"""Tests for get_runner_metadata_for_pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_options_and_stages(self):
|
||||
"""Get metadata options and stages for pipeline UI."""
|
||||
ap = FakeApplication()
|
||||
registry = AgentRunnerRegistry(ap)
|
||||
|
||||
options, stages = await registry.get_runner_metadata_for_pipeline()
|
||||
|
||||
# Should have options for each runner
|
||||
assert len(options) == 2
|
||||
option_ids = [o['name'] for o in options]
|
||||
assert 'plugin:langbot/local-agent/default' in option_ids
|
||||
assert 'plugin:alice/my-agent/custom' in option_ids
|
||||
|
||||
# Config comes from the typed manifest.
|
||||
assert len(stages) == 1
|
||||
assert stages[0]['name'] == 'plugin:alice/my-agent/custom'
|
||||
assert stages[0]['config'][0]['name'] == 'param1'
|
||||
assert stages[0]['config'][0]['type'] == 'string'
|
||||
assert stages[0]['config'][0]['id'] == 'plugin:alice/my-agent/custom.param1'
|
||||
|
||||
|
||||
class TestDescriptorValidation:
|
||||
"""Tests for descriptor validation."""
|
||||
|
||||
def test_validate_runner_descriptor(self):
|
||||
"""Validate correctly built descriptor."""
|
||||
descriptor = AgentRunnerDescriptor(
|
||||
id='plugin:test/my-runner/default',
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='my-runner',
|
||||
runner_name='default',
|
||||
)
|
||||
|
||||
assert descriptor.id == 'plugin:test/my-runner/default'
|
||||
assert descriptor.get_plugin_id() == 'test/my-runner'
|
||||
assert 'protocol_version' not in AgentRunnerDescriptor.model_fields
|
||||
|
||||
def test_descriptor_capabilities(self):
|
||||
"""Descriptor capability helper methods."""
|
||||
descriptor = AgentRunnerDescriptor(
|
||||
id='plugin:test/my-runner/default',
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='my-runner',
|
||||
runner_name='default',
|
||||
capabilities={'streaming': True, 'tool_calling': False},
|
||||
)
|
||||
|
||||
assert descriptor.supports_streaming() is True
|
||||
assert descriptor.supports_tool_calling() is False
|
||||
assert descriptor.supports_knowledge_retrieval() is False
|
||||
@@ -0,0 +1,400 @@
|
||||
"""Tests for AgentResourceBuilder."""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.binding_resolver import AgentBindingResolver
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
|
||||
|
||||
|
||||
RUNNER_ID = 'plugin:test/runner/default'
|
||||
FULL_PERMISSIONS = {
|
||||
'models': ['invoke', 'stream', 'rerank'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
'history': ['page', 'search'],
|
||||
'events': ['get', 'page'],
|
||||
'storage': ['plugin', 'workspace'],
|
||||
}
|
||||
|
||||
|
||||
def make_descriptor(
|
||||
*,
|
||||
config_schema: list[dict] | None = None,
|
||||
capabilities: dict | None = None,
|
||||
permissions: dict | None = None,
|
||||
) -> AgentRunnerDescriptor:
|
||||
return AgentRunnerDescriptor(
|
||||
id=RUNNER_ID,
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='runner',
|
||||
runner_name='default',
|
||||
capabilities=capabilities or {},
|
||||
permissions=permissions if permissions is not None else FULL_PERMISSIONS,
|
||||
config_schema=config_schema or [],
|
||||
)
|
||||
|
||||
|
||||
def make_model(model_type='llm', provider='test-provider'):
|
||||
return SimpleNamespace(
|
||||
model_entity=SimpleNamespace(model_type=model_type),
|
||||
provider_entity=SimpleNamespace(name=provider),
|
||||
)
|
||||
|
||||
|
||||
def make_query(
|
||||
runner_config: dict,
|
||||
*,
|
||||
variables: dict | None = None,
|
||||
use_llm_model_uuid=None,
|
||||
use_funcs: list | None = None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
query_id=1,
|
||||
bot_uuid='bot_001',
|
||||
launcher_type='person',
|
||||
launcher_id='launcher_001',
|
||||
sender_id='sender_001',
|
||||
message_event=None,
|
||||
message_chain=None,
|
||||
user_message=None,
|
||||
session=None,
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'id': RUNNER_ID},
|
||||
'runner_config': {RUNNER_ID: runner_config},
|
||||
},
|
||||
},
|
||||
variables=variables or {},
|
||||
use_llm_model_uuid=use_llm_model_uuid,
|
||||
use_funcs=use_funcs or [],
|
||||
pipeline_uuid='pipeline_001',
|
||||
)
|
||||
|
||||
|
||||
async def build_resources(app, query, descriptor):
|
||||
event = QueryEntryAdapter.query_to_event(query)
|
||||
agent_config = QueryEntryAdapter.config_to_agent_config(query, descriptor.id)
|
||||
binding = AgentBindingResolver().resolve_one(event, [agent_config])
|
||||
return await AgentResourceBuilder(app).build_resources_from_binding(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
mock_app = Mock()
|
||||
mock_app.logger = Mock()
|
||||
mock_app.model_mgr = Mock()
|
||||
mock_app.rag_mgr = Mock()
|
||||
mock_app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(return_value=None)
|
||||
return mock_app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app):
|
||||
"""DynamicForm model selectors should become run-scoped authorized models."""
|
||||
llm_models = {
|
||||
'primary': make_model(),
|
||||
'fallback': make_model(),
|
||||
'aux': make_model(provider='aux-provider'),
|
||||
}
|
||||
rerank_models = {
|
||||
'rerank': make_model(model_type='rerank', provider='rerank-provider'),
|
||||
}
|
||||
|
||||
async def get_model_by_uuid(model_uuid):
|
||||
return llm_models.get(model_uuid)
|
||||
|
||||
async def get_rerank_model_by_uuid(model_uuid):
|
||||
return rerank_models.get(model_uuid)
|
||||
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid)
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid)
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
{'name': 'aux-model', 'type': 'llm-model-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query({
|
||||
'model': {'primary': 'primary', 'fallbacks': ['fallback', 'primary']},
|
||||
'aux-model': 'aux',
|
||||
'rerank-model': 'rerank',
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'primary', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'fallback', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'aux', 'model_type': 'llm', 'provider': 'aux-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_from_config_without_manifest_acl(app):
|
||||
"""Config-selected models are not projected without manifest model permissions."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=make_model(model_type='rerank'))
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
permissions={},
|
||||
)
|
||||
query = make_query({
|
||||
'model': {'primary': 'primary', 'fallbacks': ['fallback']},
|
||||
'rerank-model': 'rerank',
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
|
||||
"""Config-selected model references are projected regardless of method granularity."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||
return_value=make_model(model_type='rerank', provider='rerank-provider')
|
||||
)
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'llm-model-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query({
|
||||
'model': 'llm',
|
||||
'rerank-model': 'rerank',
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_resources_accepts_dynamic_form_type_aliases(app):
|
||||
"""Frontend DynamicForm aliases should resolve to runtime resource grants."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
|
||||
async def get_kb(kb_uuid):
|
||||
return SimpleNamespace(
|
||||
uuid=kb_uuid,
|
||||
get_name=lambda: f'name-{kb_uuid}',
|
||||
knowledge_base_entity=SimpleNamespace(kb_type='default'),
|
||||
)
|
||||
|
||||
app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb)
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'select-llm-model'},
|
||||
{'name': 'knowledge-bases', 'type': 'select-knowledge-bases'},
|
||||
],
|
||||
)
|
||||
query = make_query({
|
||||
'model': 'llm_alias',
|
||||
'knowledge-bases': ['kb_alias'],
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
]
|
||||
assert resources['knowledge_bases'] == [
|
||||
{'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_manifest_permission_narrows_binding(app):
|
||||
"""Manifest model permissions narrower than binding should remove LLM grants."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||
return_value=make_model(model_type='rerank', provider='rerank-provider')
|
||||
)
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'llm-model-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'models': ['rerank'],
|
||||
},
|
||||
)
|
||||
query = make_query({
|
||||
'model': 'llm',
|
||||
'rerank-model': 'rerank',
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_deduplicates_query_and_config_models(app):
|
||||
"""A model selected by both preproc and runner config should appear once."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=None)
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query(
|
||||
{'model': {'primary': 'primary', 'fallbacks': ['fallback']}},
|
||||
variables={'_fallback_model_uuids': ['fallback']},
|
||||
use_llm_model_uuid='primary',
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert [model['model_id'] for model in resources['models']] == ['primary', 'fallback']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_tools_authorizes_query_declared_tools(app):
|
||||
"""Tools discovered by Pipeline preprocessing become run-scoped authorized resources."""
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'tool_calling': True},
|
||||
)
|
||||
query = make_query(
|
||||
{},
|
||||
use_funcs=[
|
||||
{'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
|
||||
SimpleNamespace(name='qa_mcp_echo'),
|
||||
],
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['tools'] == [
|
||||
{
|
||||
'tool_name': 'qa_plugin_echo',
|
||||
'tool_type': None,
|
||||
'description': None,
|
||||
'operations': ['detail', 'call'],
|
||||
},
|
||||
{
|
||||
'tool_name': 'qa_mcp_echo',
|
||||
'tool_type': None,
|
||||
'description': None,
|
||||
'operations': ['detail', 'call'],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_tools_manifest_permission_denies_binding_tools(app):
|
||||
"""Binding tool grants should be removed when manifest does not request tools."""
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'tool_calling': True},
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'tools': [],
|
||||
},
|
||||
)
|
||||
query = make_query(
|
||||
{},
|
||||
use_funcs=[
|
||||
{'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
|
||||
],
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['tools'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
config_schema=[
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query(
|
||||
{'knowledge-bases': ['kb_config']},
|
||||
variables={'_knowledge_base_uuids': ['kb_policy']},
|
||||
)
|
||||
|
||||
async def get_kb(kb_uuid):
|
||||
return SimpleNamespace(
|
||||
uuid=kb_uuid,
|
||||
get_name=lambda: f'name-{kb_uuid}',
|
||||
knowledge_base_entity=SimpleNamespace(kb_type='default'),
|
||||
)
|
||||
|
||||
app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['knowledge_bases'] == [
|
||||
{'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
|
||||
{'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_knowledge_bases_manifest_permission_denies_binding_kbs(app):
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'knowledge_bases': [],
|
||||
},
|
||||
config_schema=[
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query(
|
||||
{'knowledge-bases': ['kb_config']},
|
||||
variables={'_knowledge_base_uuids': ['kb_policy']},
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['knowledge_bases'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_storage_intersects_manifest_and_binding_policy(app):
|
||||
descriptor = make_descriptor(
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'storage': ['plugin'],
|
||||
},
|
||||
)
|
||||
query = make_query({})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['storage'] == {
|
||||
'plugin_storage': True,
|
||||
'workspace_storage': False,
|
||||
}
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Tests for agent runner result normalizer."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.result_normalizer import AgentResultNormalizer
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.errors import RunnerExecutionError, RunnerProtocolError
|
||||
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
"""Fake Application for testing."""
|
||||
def __init__(self):
|
||||
class FakeLogger:
|
||||
def __init__(self):
|
||||
self.warnings = []
|
||||
|
||||
def info(self, msg):
|
||||
pass
|
||||
def debug(self, msg):
|
||||
pass
|
||||
def warning(self, msg):
|
||||
self.warnings.append(msg)
|
||||
def error(self, msg):
|
||||
pass
|
||||
|
||||
self.logger = FakeLogger()
|
||||
|
||||
|
||||
def make_descriptor():
|
||||
"""Create a test descriptor."""
|
||||
return AgentRunnerDescriptor(
|
||||
id='plugin:langbot/local-agent/default',
|
||||
source='plugin',
|
||||
label={'en_US': 'Local Agent', 'zh_Hans': '内置 Agent'},
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
capabilities={'streaming': True},
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeMessageDelta:
|
||||
"""Tests for normalizing message.delta results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_delta_text(self):
|
||||
"""Normalize message.delta with text chunk."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.delta',
|
||||
'data': {
|
||||
'chunk': {
|
||||
'role': 'assistant',
|
||||
'content': 'Hello',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, provider_message.MessageChunk)
|
||||
assert result.role == 'assistant'
|
||||
assert result.content == 'Hello'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_delta_missing_chunk(self):
|
||||
"""Invalid message.delta payload is dropped."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.delta',
|
||||
'data': {},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestNormalizeMessageCompleted:
|
||||
"""Tests for normalizing message.completed results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_completed(self):
|
||||
"""Normalize message.completed with full message."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': 'Complete response',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, provider_message.Message)
|
||||
assert result.role == 'assistant'
|
||||
assert result.content == 'Complete response'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_completed_missing_message(self):
|
||||
"""Invalid message.completed payload is dropped."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestNormalizeRunCompleted:
|
||||
"""Tests for normalizing run.completed results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_run_completed_with_message(self):
|
||||
"""Normalize run.completed with final message."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'run.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': 'Final response',
|
||||
},
|
||||
'finish_reason': 'stop',
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, provider_message.Message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_run_completed_without_message(self):
|
||||
"""Normalize run.completed without message."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'run.completed',
|
||||
'data': {
|
||||
'finish_reason': 'stop',
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestNormalizeRunFailed:
|
||||
"""Tests for normalizing run.failed results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_run_failed(self):
|
||||
"""Normalize run.failed raises RunnerExecutionError."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'run.failed',
|
||||
'data': {
|
||||
'error': 'Upstream timeout',
|
||||
'code': 'upstream.timeout',
|
||||
'retryable': True,
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerExecutionError) as exc_info:
|
||||
await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert exc_info.value.runner_id == 'plugin:langbot/local-agent/default'
|
||||
assert exc_info.value.retryable is True
|
||||
assert 'timeout' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestNormalizeNonMessageResults:
|
||||
"""Tests for normalizing non-message results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_tool_call_started(self):
|
||||
"""Normalize tool.call.started returns None."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'tool.call.started',
|
||||
'data': {
|
||||
'tool_call_id': 'call_1',
|
||||
'tool_name': 'weather',
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_tool_call_completed(self):
|
||||
"""Normalize tool.call.completed returns None."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'tool.call.completed',
|
||||
'data': {
|
||||
'tool_call_id': 'call_1',
|
||||
'tool_name': 'weather',
|
||||
'result': {'temp': 20},
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_state_updated(self):
|
||||
"""Normalize state.updated returns None."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'state.updated',
|
||||
'data': {
|
||||
'scope': 'conversation',
|
||||
'key': 'external_conversation_id',
|
||||
'value': 'abc123',
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_action_requested(self):
|
||||
"""Normalize action.requested returns None (EBA reserved)."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'action.requested',
|
||||
'data': {
|
||||
'action': 'platform.message.edit',
|
||||
'payload': {},
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_state_updated_payload_is_dropped(self):
|
||||
"""Invalid state.updated payload returns None with a warning."""
|
||||
app = FakeApplication()
|
||||
normalizer = AgentResultNormalizer(app)
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result = await normalizer.normalize(
|
||||
{
|
||||
'type': 'state.updated',
|
||||
'data': {
|
||||
'scope': 'invalid',
|
||||
'key': 'k',
|
||||
'value': 'v',
|
||||
},
|
||||
},
|
||||
descriptor,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert app.logger.warnings
|
||||
|
||||
class TestNormalizeInvalidResults:
|
||||
"""Tests for handling invalid results."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_missing_type(self):
|
||||
"""Normalize result without type."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'data': {},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert 'Missing result type' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_unknown_type(self):
|
||||
"""Normalize unknown type returns None."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result_dict = {
|
||||
'type': 'unknown_type',
|
||||
'data': {},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_legacy_type_returns_none(self):
|
||||
"""Legacy types (chunk, text, finish) are now treated as unknown."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
# chunk is now unknown
|
||||
result_dict = {
|
||||
'type': 'chunk',
|
||||
'data': {
|
||||
'message_chunk': {
|
||||
'role': 'assistant',
|
||||
'content': 'Legacy chunk',
|
||||
},
|
||||
},
|
||||
}
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
# text is now unknown
|
||||
result_dict = {
|
||||
'type': 'text',
|
||||
'data': {
|
||||
'content': 'Legacy text',
|
||||
},
|
||||
}
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
# finish is now unknown
|
||||
result_dict = {
|
||||
'type': 'finish',
|
||||
'data': {
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': 'Legacy finish',
|
||||
},
|
||||
},
|
||||
}
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,430 @@
|
||||
"""Tests for RunLedgerStore host primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore
|
||||
from langbot.pkg.entity.persistence.agent_run import AgentRun
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(tmp_path):
|
||||
db_path = tmp_path / 'run_ledger_store.db'
|
||||
engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(db_engine):
|
||||
return RunLedgerStore(db_engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_queued_run_claim_renew_release(store):
|
||||
run = await store.create_run(
|
||||
run_id='run-queued',
|
||||
event_id='evt-1',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=10,
|
||||
requested_runtime_id='runtime-a',
|
||||
)
|
||||
|
||||
assert run['status'] == 'queued'
|
||||
assert run['started_at'] is None
|
||||
assert run['queue_name'] == 'default'
|
||||
assert run['priority'] == 10
|
||||
assert run['requested_runtime_id'] == 'runtime-a'
|
||||
|
||||
assert await store.claim_next_run(runtime_id='runtime-b', queue_name='default') is None
|
||||
|
||||
claimed = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=30)
|
||||
|
||||
assert claimed is not None
|
||||
assert claimed['run_id'] == 'run-queued'
|
||||
assert claimed['status'] == 'claimed'
|
||||
assert claimed['claimed_by_runtime_id'] == 'runtime-a'
|
||||
assert claimed['claim_token']
|
||||
assert claimed['dispatch_attempts'] == 1
|
||||
assert claimed['claim_lease_expires_at'] is not None
|
||||
assert claimed['last_claimed_at'] is not None
|
||||
|
||||
token = claimed['claim_token']
|
||||
assert await store.renew_claim(run_id='run-queued', claim_token='wrong-token') is None
|
||||
|
||||
renewed = await store.renew_claim(run_id='run-queued', claim_token=token, lease_seconds=90)
|
||||
|
||||
assert renewed is not None
|
||||
assert 'claim_token' not in renewed
|
||||
assert renewed['claim_lease_expires_at'] >= claimed['claim_lease_expires_at']
|
||||
|
||||
released = await store.release_claim(
|
||||
run_id='run-queued',
|
||||
claim_token=token,
|
||||
status='queued',
|
||||
status_reason='runtime released capacity',
|
||||
)
|
||||
|
||||
assert released is not None
|
||||
assert released['status'] == 'queued'
|
||||
assert released['status_reason'] == 'runtime released capacity'
|
||||
assert released['claimed_by_runtime_id'] is None
|
||||
assert 'claim_token' not in released
|
||||
assert released['claim_lease_expires_at'] is None
|
||||
assert released['dispatch_attempts'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_next_run_applies_scope_filters(store):
|
||||
await store.create_run(
|
||||
run_id='run-other-runner',
|
||||
event_id='evt-other-runner',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-b',
|
||||
conversation_id='conv-a',
|
||||
bot_id='bot-a',
|
||||
workspace_id='workspace-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=30,
|
||||
)
|
||||
await store.create_run(
|
||||
run_id='run-other-conversation',
|
||||
event_id='evt-other-conversation',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
conversation_id='conv-b',
|
||||
bot_id='bot-a',
|
||||
workspace_id='workspace-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=20,
|
||||
)
|
||||
await store.create_run(
|
||||
run_id='run-allowed',
|
||||
event_id='evt-allowed',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
conversation_id='conv-a',
|
||||
bot_id='bot-a',
|
||||
workspace_id='workspace-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=10,
|
||||
)
|
||||
|
||||
claimed = await store.claim_next_run(
|
||||
runtime_id='runtime-a',
|
||||
queue_name='default',
|
||||
runner_ids=['runner-a'],
|
||||
conversation_id='conv-a',
|
||||
bot_id='bot-a',
|
||||
workspace_id='workspace-a',
|
||||
thread_id=None,
|
||||
strict_thread=True,
|
||||
)
|
||||
|
||||
assert claimed is not None
|
||||
assert claimed['run_id'] == 'run-allowed'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_claim_can_be_reclaimed(store, db_engine):
|
||||
await store.create_run(
|
||||
run_id='run-expired',
|
||||
event_id='evt-2',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
first_claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
|
||||
assert first_claim is not None
|
||||
|
||||
session_factory = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(AgentRun)
|
||||
.where(AgentRun.run_id == 'run-expired')
|
||||
.values(claim_lease_expires_at=datetime.datetime.now(UTC) - datetime.timedelta(seconds=1))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
reclaimed = await store.claim_next_run(runtime_id='runtime-b', queue_name='default', lease_seconds=60)
|
||||
|
||||
assert reclaimed is not None
|
||||
assert reclaimed['run_id'] == 'run-expired'
|
||||
assert reclaimed['claimed_by_runtime_id'] == 'runtime-b'
|
||||
assert reclaimed['claim_token'] != first_claim['claim_token']
|
||||
assert reclaimed['dispatch_attempts'] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_expired_claims_requeues_runs(store, db_engine):
|
||||
await store.create_run(
|
||||
run_id='run-expired-release',
|
||||
event_id='evt-3',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
await store.create_run(
|
||||
run_id='run-active-claim',
|
||||
event_id='evt-4',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
expired_claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
|
||||
active_claim = await store.claim_next_run(runtime_id='runtime-b', queue_name='default', lease_seconds=60)
|
||||
assert expired_claim is not None
|
||||
assert active_claim is not None
|
||||
|
||||
session_factory = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(AgentRun)
|
||||
.where(AgentRun.run_id == 'run-expired-release')
|
||||
.values(claim_lease_expires_at=datetime.datetime.now(UTC) - datetime.timedelta(seconds=1))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
released = await store.release_expired_claims()
|
||||
|
||||
assert [run['run_id'] for run in released] == ['run-expired-release']
|
||||
assert released[0]['status'] == 'queued'
|
||||
assert released[0]['status_reason'] == 'claim lease expired'
|
||||
assert released[0]['claimed_by_runtime_id'] is None
|
||||
assert 'claim_token' not in released[0]
|
||||
assert released[0]['claim_lease_expires_at'] is None
|
||||
|
||||
active = await store.get_run('run-active-claim')
|
||||
assert active is not None
|
||||
assert active['status'] == 'claimed'
|
||||
assert active['claimed_by_runtime_id'] == active_claim['claimed_by_runtime_id']
|
||||
assert 'claim_token' not in active
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_claim_cannot_renew_or_release(store, db_engine):
|
||||
await store.create_run(
|
||||
run_id='run-stale-claim',
|
||||
event_id='evt-stale',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
claimed = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
|
||||
assert claimed is not None
|
||||
token = claimed['claim_token']
|
||||
|
||||
session_factory = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(AgentRun)
|
||||
.where(AgentRun.run_id == 'run-stale-claim')
|
||||
.values(claim_lease_expires_at=datetime.datetime.now(UTC) - datetime.timedelta(seconds=1))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
assert await store.renew_claim(run_id='run-stale-claim', claim_token=token, runtime_id='runtime-a') is None
|
||||
assert await store.release_claim(run_id='run-stale-claim', claim_token=token, runtime_id='runtime-a') is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_status_validation_and_terminal_transition_rules(store):
|
||||
with pytest.raises(ValueError, match='Unknown run status'):
|
||||
await store.create_run(
|
||||
run_id='run-invalid-create',
|
||||
event_id='evt-invalid',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='bogus',
|
||||
)
|
||||
|
||||
await store.create_run(
|
||||
run_id='run-invalid-release',
|
||||
event_id='evt-release',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default')
|
||||
assert claim is not None
|
||||
with pytest.raises(ValueError, match='Unknown run status'):
|
||||
await store.release_claim(
|
||||
run_id='run-invalid-release',
|
||||
claim_token=claim['claim_token'],
|
||||
runtime_id='runtime-a',
|
||||
status='bogus',
|
||||
)
|
||||
|
||||
await store.create_run(
|
||||
run_id='run-terminal',
|
||||
event_id='evt-terminal',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
)
|
||||
with pytest.raises(ValueError, match='Unknown run status'):
|
||||
await store.finalize_run(run_id='run-terminal', status='bogus')
|
||||
|
||||
completed = await store.finalize_run(
|
||||
run_id='run-terminal',
|
||||
status='completed',
|
||||
metadata={'attempt': 1},
|
||||
)
|
||||
assert completed is not None
|
||||
assert completed['status'] == 'completed'
|
||||
|
||||
merged = await store.finalize_run(
|
||||
run_id='run-terminal',
|
||||
status='completed',
|
||||
metadata={'retry_observed': True},
|
||||
)
|
||||
assert merged is not None
|
||||
assert merged['metadata'] == {'attempt': 1, 'retry_observed': True}
|
||||
|
||||
with pytest.raises(ValueError, match='Cannot transition terminal run'):
|
||||
await store.finalize_run(run_id='run-terminal', status='failed')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_audit_event_uses_next_sequence(store):
|
||||
await store.create_run(
|
||||
run_id='run-audit',
|
||||
event_id='evt-5',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
)
|
||||
await store.append_event(
|
||||
run_id='run-audit',
|
||||
sequence=1,
|
||||
event_type='message.completed',
|
||||
data={'ok': True},
|
||||
)
|
||||
|
||||
event = await store.append_audit_event(
|
||||
run_id='run-audit',
|
||||
event_type='admin.run_cancel',
|
||||
data={'action': 'run_cancel'},
|
||||
metadata={'permission': 'agent_run:admin'},
|
||||
)
|
||||
|
||||
assert event is not None
|
||||
assert event['sequence'] == 2
|
||||
assert event['type'] == 'admin.run_cancel'
|
||||
assert event['source'] == 'host'
|
||||
assert event['data'] == {'action': 'run_cancel'}
|
||||
assert event['metadata'] == {'permission': 'agent_run:admin'}
|
||||
assert await store.append_audit_event(run_id='missing', event_type='admin.missing') is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_register_heartbeat_list_and_mark_stale(store):
|
||||
registered = await store.register_runtime(
|
||||
runtime_id='runtime-a',
|
||||
display_name='Runtime A',
|
||||
endpoint='http://runtime-a',
|
||||
version='1.0.0',
|
||||
capabilities={'stream': True},
|
||||
labels={'region': 'test'},
|
||||
metadata={'slot_count': 2},
|
||||
heartbeat_deadline_seconds=30,
|
||||
)
|
||||
|
||||
assert registered['runtime_id'] == 'runtime-a'
|
||||
assert registered['status'] == 'online'
|
||||
assert registered['display_name'] == 'Runtime A'
|
||||
assert registered['capabilities'] == {'stream': True}
|
||||
assert registered['labels'] == {'region': 'test'}
|
||||
assert registered['metadata'] == {'slot_count': 2}
|
||||
assert registered['last_heartbeat_at'] is not None
|
||||
assert registered['heartbeat_deadline_at'] is not None
|
||||
|
||||
heartbeat = await store.heartbeat_runtime(
|
||||
runtime_id='runtime-a',
|
||||
metadata={'active_runs': 1},
|
||||
heartbeat_deadline_seconds=30,
|
||||
)
|
||||
|
||||
assert heartbeat is not None
|
||||
assert heartbeat['metadata'] == {'slot_count': 2, 'active_runs': 1}
|
||||
|
||||
runtimes, total_count = await store.list_runtimes(statuses=['online'])
|
||||
assert [runtime['runtime_id'] for runtime in runtimes] == ['runtime-a']
|
||||
assert total_count == 1
|
||||
|
||||
stale = await store.mark_stale_runtimes(
|
||||
now=datetime.datetime.now(UTC) + datetime.timedelta(seconds=31),
|
||||
)
|
||||
|
||||
assert [runtime['runtime_id'] for runtime in stale] == ['runtime-a']
|
||||
assert stale[0]['status'] == 'stale'
|
||||
assert (await store.get_runtime('runtime-a'))['status'] == 'stale'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_stats_splits_active_and_claimed_runs(store):
|
||||
await store.register_runtime(runtime_id='runtime-a')
|
||||
await store.create_run(
|
||||
run_id='run-running',
|
||||
event_id='evt-running',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='running',
|
||||
)
|
||||
await store.create_run(
|
||||
run_id='run-claimed',
|
||||
event_id='evt-claimed',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
assert await store.claim_next_run(runtime_id='runtime-a', queue_name='default') is not None
|
||||
|
||||
stats = await store.get_runtime_stats()
|
||||
|
||||
assert stats['active_runs'] == 2
|
||||
assert stats['claimed_runs'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_stats_reports_zero_success_rate_for_failed_only_runner(store):
|
||||
now = int(datetime.datetime.now(UTC).timestamp())
|
||||
await store.create_run(
|
||||
run_id='run-failed',
|
||||
event_id='evt-failed',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='failed',
|
||||
)
|
||||
|
||||
stats = await store.get_runner_stats(start_time=now - 10, end_time=now + 10)
|
||||
|
||||
assert stats[0]['runner_id'] == 'runner-a'
|
||||
assert stats[0]['failed_runs'] == 1
|
||||
assert stats[0]['success_rate'] == 0.0
|
||||
@@ -0,0 +1,633 @@
|
||||
"""Tests for AgentRunSessionRegistry."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from langbot.pkg.agent.runner.session_registry import (
|
||||
AgentRunSessionRegistry,
|
||||
AgentRunSession,
|
||||
MAX_STEERING_QUEUE_ITEMS,
|
||||
get_session_registry,
|
||||
)
|
||||
|
||||
# Import shared test fixtures from conftest.py
|
||||
from .conftest import make_resources, make_session
|
||||
|
||||
|
||||
class TestSessionRegistryBasic:
|
||||
"""Tests for basic registry operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_and_get(self):
|
||||
"""Register and retrieve a session."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
run_id = 'run_abc'
|
||||
resources = make_resources(
|
||||
models=[{'model_id': 'model_001', 'model_type': 'chat', 'provider': 'openai'}],
|
||||
tools=[{'tool_name': 'web_search', 'tool_type': 'builtin'}],
|
||||
)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
result = await registry.get(run_id)
|
||||
assert result is not None
|
||||
assert result['run_id'] == run_id
|
||||
assert result['runner_id'] == 'plugin:test/my-runner/default'
|
||||
assert result['query_id'] == 1
|
||||
assert result['plugin_identity'] == 'test/my-runner'
|
||||
auth_resources = result['authorization']['resources']
|
||||
assert len(auth_resources['models']) == 1
|
||||
assert auth_resources['models'][0]['model_id'] == 'model_001'
|
||||
assert 'resources' not in result
|
||||
assert 'permissions' not in result
|
||||
assert '_authorized_ids' not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_requires_plugin_identity(self):
|
||||
"""Agent run sessions must always have an owning plugin identity."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
with pytest.raises(ValueError, match='plugin_identity is required'):
|
||||
await registry.register(
|
||||
run_id='run_missing_identity',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='',
|
||||
resources=make_resources(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_freezes_authorization_snapshot(self):
|
||||
"""Register should freeze authorization data for the run."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
models=[{'model_id': 'model_001'}],
|
||||
storage={'plugin_storage': True, 'workspace_storage': False},
|
||||
)
|
||||
|
||||
await registry.register(
|
||||
run_id='run_snapshot',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=resources,
|
||||
available_apis={'history_page': True},
|
||||
conversation_id='conv_001',
|
||||
)
|
||||
|
||||
resources['models'].append({'model_id': 'model_late'})
|
||||
resources['storage']['workspace_storage'] = True
|
||||
|
||||
session = await registry.get('run_snapshot')
|
||||
assert session is not None
|
||||
authorization = session['authorization']
|
||||
assert authorization['conversation_id'] == 'conv_001'
|
||||
assert authorization['available_apis'] == {'history_page': True}
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001') is True
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_late') is False
|
||||
assert registry.is_resource_allowed(session, 'storage', 'workspace') is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_session(self):
|
||||
"""Get should return None for nonexistent run_id."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
result = await registry.get('nonexistent_run')
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister(self):
|
||||
"""Unregister should remove session."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
run_id = 'run_xyz'
|
||||
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
)
|
||||
|
||||
# Verify registered
|
||||
result = await registry.get(run_id)
|
||||
assert result is not None
|
||||
|
||||
# Unregister
|
||||
await registry.unregister(run_id)
|
||||
|
||||
# Verify unregistered
|
||||
result = await registry.get(run_id)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_nonexistent(self):
|
||||
"""Unregister nonexistent session should not raise error."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
# Should not raise
|
||||
await registry.unregister('nonexistent_run')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_activity(self):
|
||||
"""Update activity should update last_activity_at."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
run_id = 'run_activity'
|
||||
|
||||
# Create session with manually set old timestamp
|
||||
now = int(time.time())
|
||||
old_session: AgentRunSession = make_session(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
)
|
||||
old_session['status'] = {
|
||||
'started_at': now - 100,
|
||||
'last_activity_at': now - 100,
|
||||
}
|
||||
|
||||
async with registry._lock:
|
||||
registry._sessions[run_id] = old_session
|
||||
|
||||
# Get initial session
|
||||
session1 = await registry.get(run_id)
|
||||
initial_time = session1['status']['last_activity_at']
|
||||
|
||||
# Update activity
|
||||
await registry.update_activity(run_id)
|
||||
|
||||
# Verify updated - should be significantly different (100 seconds)
|
||||
session2 = await registry.get(run_id)
|
||||
assert session2['status']['last_activity_at'] > initial_time
|
||||
assert session2['status']['last_activity_at'] - initial_time >= 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_activity_nonexistent(self):
|
||||
"""Update activity on nonexistent session should not raise."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
# Should not raise
|
||||
await registry.update_activity('nonexistent_run')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_active_runs(self):
|
||||
"""List active runs should return all sessions."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
await registry.register('run_1', 'plugin:a/b/default', 1, 'a/b', make_resources())
|
||||
await registry.register('run_2', 'plugin:c/d/default', 2, 'c/d', make_resources())
|
||||
|
||||
active_runs = await registry.list_active_runs()
|
||||
assert len(active_runs) == 2
|
||||
run_ids = [r['run_id'] for r in active_runs]
|
||||
assert 'run_1' in run_ids
|
||||
assert 'run_2' in run_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_stale_sessions(self):
|
||||
"""Cleanup should remove old sessions."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
# Create sessions with manually set old timestamp
|
||||
now = int(time.time())
|
||||
old_session: AgentRunSession = make_session(
|
||||
run_id='old_run',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
)
|
||||
old_session['status'] = {
|
||||
'started_at': now - 7200,
|
||||
'last_activity_at': now - 7200,
|
||||
}
|
||||
new_session: AgentRunSession = make_session(
|
||||
run_id='new_run',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=2,
|
||||
plugin_identity='test/runner',
|
||||
)
|
||||
new_session['status'] = {
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
}
|
||||
|
||||
async with registry._lock:
|
||||
registry._sessions['old_run'] = old_session
|
||||
registry._sessions['new_run'] = new_session
|
||||
|
||||
# Cleanup sessions older than 1 hour
|
||||
cleaned = await registry.cleanup_stale_sessions(max_age_seconds=3600)
|
||||
assert cleaned == 1
|
||||
|
||||
# Verify old session removed, new remains
|
||||
assert await registry.get('old_run') is None
|
||||
assert await registry.get('new_run') is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pull_steering_all_preserves_queue_order(self):
|
||||
"""Default all-mode steering returns every queued item in FIFO order."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
|
||||
await registry.enqueue_steering('run_steering', {'event': {'event_id': 'event_1'}, 'input': {'text': 'first'}})
|
||||
await registry.enqueue_steering('run_steering', {'event': {'event_id': 'event_2'}, 'input': {'text': 'second'}})
|
||||
await registry.enqueue_steering('run_steering', {'event': {'event_id': 'event_3'}, 'input': {'text': 'third'}})
|
||||
|
||||
items = await registry.pull_steering('run_steering', mode='all')
|
||||
assert [item['event']['event_id'] for item in items] == ['event_1', 'event_2', 'event_3']
|
||||
assert await registry.pull_steering('run_steering', mode='all') == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pull_steering_one_at_a_time_leaves_remaining_items(self):
|
||||
"""one-at-a-time is an explicit runner-side throttling mode."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering_one',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
|
||||
await registry.enqueue_steering('run_steering_one', {'event': {'event_id': 'event_1'}})
|
||||
await registry.enqueue_steering('run_steering_one', {'event': {'event_id': 'event_2'}})
|
||||
|
||||
first = await registry.pull_steering('run_steering_one', mode='one-at-a-time')
|
||||
second = await registry.pull_steering('run_steering_one', mode='one-at-a-time')
|
||||
|
||||
assert [item['event']['event_id'] for item in first] == ['event_1']
|
||||
assert [item['event']['event_id'] for item in second] == ['event_2']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_steering_rejects_when_queue_is_full(self):
|
||||
"""A full steering queue does not claim more queries."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering_full',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
|
||||
for index in range(MAX_STEERING_QUEUE_ITEMS):
|
||||
assert await registry.enqueue_steering(
|
||||
'run_steering_full',
|
||||
{'event': {'event_id': f'event_{index}'}},
|
||||
)
|
||||
|
||||
assert not await registry.enqueue_steering(
|
||||
'run_steering_full',
|
||||
{'event': {'event_id': 'overflow'}},
|
||||
)
|
||||
|
||||
items = await registry.pull_steering('run_steering_full', mode='all')
|
||||
assert len(items) == MAX_STEERING_QUEUE_ITEMS
|
||||
assert all(item['event']['event_id'] != 'overflow' for item in items)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_steering_target_requires_same_scope(self):
|
||||
"""Steering claims must not cross bot/workspace/thread boundaries."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering_scoped',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
) == 'run_steering_scoped'
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_2',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
) is None
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_2',
|
||||
) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_returns_pending_steering_queue(self):
|
||||
"""Unregister returns the removed session so callers can audit pending steering."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering_unregister',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
await registry.enqueue_steering(
|
||||
'run_steering_unregister',
|
||||
{'event': {'event_id': 'event_pending'}},
|
||||
)
|
||||
|
||||
session = await registry.unregister('run_steering_unregister')
|
||||
|
||||
assert session is not None
|
||||
assert session['steering_queue'][0]['event']['event_id'] == 'event_pending'
|
||||
assert await registry.get('run_steering_unregister') is None
|
||||
|
||||
|
||||
class TestIsResourceAllowed:
|
||||
"""Tests for is_resource_allowed validation."""
|
||||
|
||||
def test_model_allowed(self):
|
||||
"""Model in resources should be allowed."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
models=[
|
||||
{'model_id': 'model_001', 'model_type': 'chat', 'provider': 'openai'},
|
||||
{'model_id': 'model_002', 'model_type': 'embedding', 'provider': 'anthropic'},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001') is True
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_002') is True
|
||||
|
||||
def test_model_operation_denied(self):
|
||||
"""Model resources should enforce operation-level grants."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
models=[
|
||||
{'model_id': 'model_001', 'operations': ['invoke']},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001', 'invoke') is True
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001', 'stream') is False
|
||||
|
||||
def test_model_not_allowed(self):
|
||||
"""Model not in resources should be denied."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(models=[{'model_id': 'model_001'}])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_999') is False
|
||||
|
||||
def test_model_empty_resources(self):
|
||||
"""Empty models list should deny all."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(models=[])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001') is False
|
||||
|
||||
def test_tool_allowed(self):
|
||||
"""Tool in resources should be allowed."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
tools=[
|
||||
{'tool_name': 'web_search', 'tool_type': 'builtin'},
|
||||
{'tool_name': 'code_exec', 'tool_type': 'plugin'},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'tool', 'web_search') is True
|
||||
assert registry.is_resource_allowed(session, 'tool', 'code_exec') is True
|
||||
|
||||
def test_tool_operation_denied(self):
|
||||
"""Tool resources should enforce detail/call grants."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
tools=[
|
||||
{'tool_name': 'web_search', 'operations': ['detail']},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'tool', 'web_search', 'detail') is True
|
||||
assert registry.is_resource_allowed(session, 'tool', 'web_search', 'call') is False
|
||||
|
||||
def test_tool_not_allowed(self):
|
||||
"""Tool not in resources should be denied."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(tools=[{'tool_name': 'web_search'}])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'tool', 'image_gen') is False
|
||||
|
||||
def test_tool_empty_resources(self):
|
||||
"""Empty tools list should deny all."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(tools=[])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'tool', 'web_search') is False
|
||||
|
||||
def test_knowledge_base_allowed(self):
|
||||
"""Knowledge base in resources should be allowed."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
knowledge_bases=[
|
||||
{'kb_id': 'kb_001', 'kb_name': 'docs', 'kb_type': 'vector'},
|
||||
{'kb_id': 'kb_002', 'kb_name': 'faq', 'kb_type': 'keyword'},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'knowledge_base', 'kb_001') is True
|
||||
assert registry.is_resource_allowed(session, 'knowledge_base', 'kb_002') is True
|
||||
|
||||
def test_knowledge_base_not_allowed(self):
|
||||
"""Knowledge base not in resources should be denied."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(knowledge_bases=[{'kb_id': 'kb_001'}])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'knowledge_base', 'kb_999') is False
|
||||
|
||||
def test_knowledge_base_empty_resources(self):
|
||||
"""Empty knowledge bases list should deny all."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(knowledge_bases=[])
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'knowledge_base', 'kb_001') is False
|
||||
|
||||
def test_skill_allowed(self):
|
||||
"""Skill in resources should be allowed."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(
|
||||
skills=[
|
||||
{'skill_name': 'demo', 'display_name': 'Demo'},
|
||||
{'skill_name': 'writer', 'display_name': 'Writer'},
|
||||
]
|
||||
)
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'skill', 'demo') is True
|
||||
assert registry.is_resource_allowed(session, 'skill', 'writer') is True
|
||||
assert registry.is_resource_allowed(session, 'skill', 'hidden') is False
|
||||
|
||||
def test_storage_plugin_allowed(self):
|
||||
"""Plugin storage permission should be checked."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(storage={'plugin_storage': True, 'workspace_storage': False})
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'storage', 'plugin') is True
|
||||
assert registry.is_resource_allowed(session, 'storage', 'workspace') is False
|
||||
|
||||
def test_storage_workspace_allowed(self):
|
||||
"""Workspace storage permission should be checked."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(storage={'plugin_storage': False, 'workspace_storage': True})
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'storage', 'plugin') is False
|
||||
assert registry.is_resource_allowed(session, 'storage', 'workspace') is True
|
||||
|
||||
def test_storage_both_denied(self):
|
||||
"""Both storage permissions denied."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
resources = make_resources(storage={'plugin_storage': False, 'workspace_storage': False})
|
||||
session = make_session(resources=resources)
|
||||
|
||||
assert registry.is_resource_allowed(session, 'storage', 'plugin') is False
|
||||
assert registry.is_resource_allowed(session, 'storage', 'workspace') is False
|
||||
|
||||
def test_unknown_resource_type(self):
|
||||
"""Unknown resource type should return False."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
session = make_session(resources=make_resources())
|
||||
|
||||
assert registry.is_resource_allowed(session, 'unknown_type', 'something') is False
|
||||
|
||||
def test_missing_resources_field(self):
|
||||
"""Missing resources field should not raise."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
session = make_session(resources={'models': []}) # Missing other fields
|
||||
|
||||
# Should not raise, should return False
|
||||
assert registry.is_resource_allowed(session, 'tool', 'web_search') is False
|
||||
assert registry.is_resource_allowed(session, 'knowledge_base', 'kb_001') is False
|
||||
|
||||
|
||||
class TestGlobalRegistry:
|
||||
"""Tests for global registry singleton."""
|
||||
|
||||
def test_get_session_registry_returns_instance(self):
|
||||
"""get_session_registry should return AgentRunSessionRegistry."""
|
||||
# Use a separate test that doesn't modify global state
|
||||
# The singleton pattern works in production, but modifying globals
|
||||
# in tests can cause UnboundLocalError due to Python scoping
|
||||
# Instead, just verify the function signature
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
assert callable(get_session_registry)
|
||||
|
||||
# Create a fresh instance directly to verify the class works
|
||||
fresh_registry = AgentRunSessionRegistry()
|
||||
assert isinstance(fresh_registry, AgentRunSessionRegistry)
|
||||
|
||||
def test_global_registry_singleton_behavior(self):
|
||||
"""The global registry should maintain singleton behavior."""
|
||||
# Test singleton behavior without modifying global state
|
||||
# In production, calling get_session_registry() multiple times
|
||||
# returns the same instance. We verify this by checking the
|
||||
# module-level variable directly.
|
||||
from langbot.pkg.agent.runner.session_registry import _global_registry
|
||||
|
||||
# Check that the global variable exists and is either None or an instance
|
||||
global_reg = _global_registry
|
||||
if global_reg is None:
|
||||
# First call creates the instance
|
||||
registry1 = get_session_registry()
|
||||
assert isinstance(registry1, AgentRunSessionRegistry)
|
||||
# Subsequent calls return the same instance
|
||||
registry2 = get_session_registry()
|
||||
assert registry1 is registry2
|
||||
else:
|
||||
# Instance already exists, verify singleton
|
||||
registry1 = get_session_registry()
|
||||
registry2 = get_session_registry()
|
||||
assert registry1 is registry2
|
||||
assert registry1 is global_reg
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Tests for asyncio.Lock thread safety."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_register(self):
|
||||
"""Concurrent register should be safe."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
# Register multiple sessions concurrently
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
tasks.append(
|
||||
registry.register(
|
||||
f'run_{i}',
|
||||
'plugin:test/runner/default',
|
||||
i,
|
||||
'test/runner',
|
||||
make_resources(),
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# All sessions should be registered
|
||||
active_runs = await registry.list_active_runs()
|
||||
assert len(active_runs) == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_register_and_unregister(self):
|
||||
"""Concurrent register and unregister should be safe."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
# Register
|
||||
await registry.register('run_1', 'plugin:test/runner/default', 1, 'test/runner', make_resources())
|
||||
|
||||
# Concurrent unregister and get
|
||||
tasks = [
|
||||
registry.unregister('run_1'),
|
||||
registry.get('run_1'),
|
||||
]
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# After both complete, session should be unregistered
|
||||
result = await registry.get('run_1')
|
||||
assert result is None
|
||||
@@ -0,0 +1,544 @@
|
||||
"""Tests for State API handler authorization in RuntimeConnectionHandler.
|
||||
|
||||
Tests focus on:
|
||||
- STATE_GET authorization
|
||||
- STATE_SET authorization
|
||||
- STATE_DELETE authorization
|
||||
- STATE_LIST authorization
|
||||
|
||||
These tests instantiate real RuntimeConnectionHandler action handlers and verify:
|
||||
- Authorization errors for missing/mismatched caller_plugin_identity
|
||||
- Authorization errors for disabled state or scope
|
||||
- Full flow: set -> get -> list -> delete with real SQLite
|
||||
|
||||
Authorization rules:
|
||||
- caller_plugin_identity is REQUIRED when session has plugin_identity
|
||||
- caller_plugin_identity must match session's plugin_identity
|
||||
- enable_state must be True
|
||||
- scope must be in state_scopes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore, reset_persistent_state_store
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
# Import shared test fixtures
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
"""Fake connection for testing."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
"""Fake Application for testing."""
|
||||
def __init__(self, db_engine=None):
|
||||
self.logger = MagicMock()
|
||||
self.logger.debug = MagicMock()
|
||||
self.logger.warning = MagicMock()
|
||||
self.logger.error = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry():
|
||||
"""Create a fresh session registry for each test."""
|
||||
return AgentRunSessionRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def persistent_store(db_engine):
|
||||
"""Create a persistent state store with real SQLite."""
|
||||
reset_persistent_state_store()
|
||||
store = PersistentStateStore(db_engine)
|
||||
|
||||
# Create the table
|
||||
from langbot.pkg.entity.persistence.agent_runner_state import AgentRunnerState
|
||||
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.run_sync(AgentRunnerState.__table__.create, checkfirst=True)
|
||||
|
||||
yield store
|
||||
reset_persistent_state_store()
|
||||
|
||||
|
||||
class TestStateAPIHandlerAuthorization:
|
||||
"""Tests for State API handler authorization with real action calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_run_id_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing run_id returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Get the STATE_GET action handler (actions dict is keyed by action value string)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call without run_id
|
||||
result = await state_get_handler({'scope': 'conversation', 'key': 'test_key'})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'run_id is required' in result.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_run_not_found_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: run_id not in session registry returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call with non-existent run_id
|
||||
result = await state_get_handler({
|
||||
'run_id': 'nonexistent_run',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not found' in result.message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_caller_plugin_identity_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing caller_plugin_identity when session has plugin_identity returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register session with plugin_identity
|
||||
await session_registry.register(
|
||||
run_id='run_test_missing_identity',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call without caller_plugin_identity
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_missing_identity',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'caller_plugin_identity is required' in result.message
|
||||
|
||||
await session_registry.unregister('run_test_missing_identity')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_caller_identity_mismatch_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: caller_plugin_identity mismatch returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_mismatch',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Call with wrong caller_plugin_identity
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_mismatch',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'other/plugin',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'mismatch' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_mismatch')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_enable_state_false_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: enable_state=False returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_disabled',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': False, 'state_scopes': []},
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_disabled',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'disabled' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_disabled')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_scope_not_enabled_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: scope not in state_scopes returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_scope_disabled',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key', 'actor': 'actor_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Request 'actor' scope which is not in state_scopes
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_scope_disabled',
|
||||
'scope': 'actor',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not enabled' in result.message.lower() or 'scope' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_scope_disabled')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_get_missing_scope_key_returns_error(self, session_registry, db_engine, persistent_store):
|
||||
"""STATE_GET: missing scope_key in state_context returns error."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_test_no_scope_key',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'}, # No scope_keys
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_test_no_scope_key',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not available' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_test_no_scope_key')
|
||||
|
||||
|
||||
class TestStateAPIFullFlowWithRealDB:
|
||||
"""Tests for complete State API flow with real SQLite database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_set_get_list_delete_flow(self, session_registry, db_engine, persistent_store):
|
||||
"""Test complete state flow: set -> get -> list -> delete with real SQLite."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register session
|
||||
await session_registry.register(
|
||||
run_id='run_full_flow',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation', 'runner']},
|
||||
state_context={
|
||||
'scope_keys': {
|
||||
'conversation': 'conv:test_runner:binding_1:conv_123',
|
||||
'runner': 'runner:test_runner:binding_1',
|
||||
},
|
||||
'binding_identity': 'binding_1',
|
||||
'conversation_id': 'conv_123',
|
||||
},
|
||||
)
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Verify session has correct state_context
|
||||
session = await session_registry.get('run_full_flow')
|
||||
assert session is not None
|
||||
state_ctx = session['authorization']['state_context']
|
||||
assert state_ctx is not None, f"state_context is None. Session keys: {list(session.keys())}"
|
||||
assert 'scope_keys' in state_ctx, f"scope_keys not in state_context: {state_ctx}"
|
||||
assert 'conversation' in state_ctx['scope_keys'], f"conversation not in scope_keys: {state_ctx['scope_keys']}"
|
||||
|
||||
# Get handlers (actions dict is keyed by action value string)
|
||||
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
state_list_handler = handler.actions[PluginToRuntimeAction.STATE_LIST.value]
|
||||
state_delete_handler = handler.actions[PluginToRuntimeAction.STATE_DELETE.value]
|
||||
|
||||
# 1. STATE_SET
|
||||
set_result = await state_set_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'value': {'data': 'test_value'},
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert set_result.code == 0
|
||||
assert set_result.data.get('success') is True
|
||||
|
||||
# 2. STATE_GET
|
||||
get_result = await state_get_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert get_result.code == 0
|
||||
assert get_result.data.get('value') == {'data': 'test_value'}
|
||||
|
||||
# 3. STATE_LIST
|
||||
list_result = await state_list_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'prefix': 'external.',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert list_result.code == 0
|
||||
keys = list_result.data.get('keys', [])
|
||||
assert 'external.test_key' in keys
|
||||
|
||||
# 4. STATE_DELETE
|
||||
delete_result = await state_delete_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert delete_result.code == 0
|
||||
|
||||
# 5. Verify deleted
|
||||
get_after_delete = await state_get_handler({
|
||||
'run_id': 'run_full_flow',
|
||||
'scope': 'conversation',
|
||||
'key': 'external.test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert get_after_delete.code == 0
|
||||
assert get_after_delete.data.get('value') is None
|
||||
|
||||
await session_registry.unregister('run_full_flow')
|
||||
|
||||
|
||||
class TestStateHandlerReadsFromAuthorizationSnapshot:
|
||||
"""Tests verifying handlers read state_policy/state_context from authorization snapshot."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_handler_reads_state_policy_from_authorization(self, session_registry, db_engine, persistent_store):
|
||||
"""Handler reads state_policy from session['authorization'], not resources."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register with explicit state_policy in the authorization snapshot
|
||||
await session_registry.register(
|
||||
run_id='run_policy_top_level',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': False, 'state_scopes': []},
|
||||
state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
# Verify resources does NOT contain state_policy
|
||||
session = await session_registry.get('run_policy_top_level')
|
||||
assert session is not None
|
||||
resources = session['authorization']['resources']
|
||||
assert 'state_policy' not in resources, "resources should NOT contain state_policy"
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
# Should fail because enable_state=False in authorization.state_policy
|
||||
result = await state_get_handler({
|
||||
'run_id': 'run_policy_top_level',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code != 0
|
||||
assert 'disabled' in result.message.lower()
|
||||
|
||||
await session_registry.unregister('run_policy_top_level')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_handler_reads_state_context_from_authorization(self, session_registry, db_engine, persistent_store):
|
||||
"""Handler reads state_context from session['authorization'], not resources."""
|
||||
fake_app = FakeApplication(db_engine)
|
||||
fake_app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
# Register with explicit state_context in the authorization snapshot
|
||||
await session_registry.register(
|
||||
run_id='run_context_top_level',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
available_apis={'state': True},
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key_xyz'}, 'binding_identity': 'binding_xyz'},
|
||||
)
|
||||
|
||||
# Verify resources does NOT contain state_context
|
||||
session = await session_registry.get('run_context_top_level')
|
||||
assert session is not None
|
||||
resources = session['authorization']['resources']
|
||||
assert 'state_context' not in resources, "resources should NOT contain state_context"
|
||||
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
|
||||
|
||||
# Should use scope_key from authorization.state_context.scope_keys.conversation
|
||||
result = await state_set_handler({
|
||||
'run_id': 'run_context_top_level',
|
||||
'scope': 'conversation',
|
||||
'key': 'test_key',
|
||||
'value': 'test_value',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
# Should succeed - scope_key was found in state_context
|
||||
assert result.code == 0
|
||||
|
||||
await session_registry.unregister('run_context_top_level')
|
||||
|
||||
|
||||
class TestResourcesDoesNotContainStateMetadata:
|
||||
"""Tests verifying resources is clean - no state metadata mixed in."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resources_clean_after_register(self, session_registry):
|
||||
"""After register(), only authorization contains resources and state metadata."""
|
||||
resources = make_resources()
|
||||
|
||||
await session_registry.register(
|
||||
run_id='run_resources_clean',
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/runner',
|
||||
resources=resources,
|
||||
state_policy={'enable_state': True, 'state_scopes': ['conversation']},
|
||||
state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
|
||||
)
|
||||
|
||||
session = await session_registry.get('run_resources_clean')
|
||||
assert session is not None
|
||||
|
||||
# Verify resources is nested under authorization and is clean.
|
||||
assert 'resources' not in session
|
||||
session_resources = session['authorization']['resources']
|
||||
assert 'state_policy' not in session_resources, \
|
||||
"authorization['resources'] should NOT contain state_policy"
|
||||
assert 'state_context' not in session_resources, \
|
||||
"authorization['resources'] should NOT contain state_context"
|
||||
|
||||
assert 'state_policy' in session['authorization']
|
||||
assert 'state_context' in session['authorization']
|
||||
|
||||
await session_registry.unregister('run_resources_clean')
|
||||
@@ -0,0 +1,383 @@
|
||||
"""Tests for persistent AgentRunner state store."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import BindingScope, StatePolicy
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
|
||||
from langbot.pkg.agent.runner.state_scope import (
|
||||
STATE_KEY_ALIASES,
|
||||
VALID_STATE_SCOPES,
|
||||
build_state_context,
|
||||
build_state_scope_key,
|
||||
get_binding_identity,
|
||||
normalize_state_key,
|
||||
)
|
||||
|
||||
|
||||
def make_descriptor(runner_id: str = 'plugin:test/my-runner/default') -> AgentRunnerDescriptor:
|
||||
"""Create a test descriptor."""
|
||||
return AgentRunnerDescriptor(
|
||||
id=runner_id,
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='my-runner',
|
||||
runner_name='default',
|
||||
capabilities={'streaming': True},
|
||||
)
|
||||
|
||||
|
||||
class FakeActorContext:
|
||||
"""Fake actor context for event testing."""
|
||||
def __init__(self, actor_type: str = 'user', actor_id: str = 'user_123', actor_name: str = 'Test User'):
|
||||
self.actor_type = actor_type
|
||||
self.actor_id = actor_id
|
||||
self.actor_name = actor_name
|
||||
|
||||
|
||||
class FakeSubjectContext:
|
||||
"""Fake subject context for event testing."""
|
||||
def __init__(self, subject_type: str = 'message', subject_id: str = 'msg_001', data: dict | None = None):
|
||||
self.subject_type = subject_type
|
||||
self.subject_id = subject_id
|
||||
self.data = data or {}
|
||||
|
||||
|
||||
class FakeEventEnvelope:
|
||||
"""Fake event envelope for testing event-first state."""
|
||||
def __init__(
|
||||
self,
|
||||
event_id: str = 'evt_001',
|
||||
event_type: str = 'message.received',
|
||||
conversation_id: str | None = 'conv_001',
|
||||
actor: FakeActorContext | None = None,
|
||||
subject: FakeSubjectContext | None = None,
|
||||
bot_id: str = 'bot_001',
|
||||
workspace_id: str = 'ws_001',
|
||||
thread_id: str | None = None,
|
||||
):
|
||||
self.event_id = event_id
|
||||
self.event_type = event_type
|
||||
self.event_time = 1700000000
|
||||
self.source = 'platform'
|
||||
self.bot_id = bot_id
|
||||
self.workspace_id = workspace_id
|
||||
self.conversation_id = conversation_id
|
||||
self.thread_id = thread_id
|
||||
self.actor = actor or FakeActorContext()
|
||||
self.subject = subject
|
||||
self.raw_ref = None
|
||||
|
||||
|
||||
class FakeBinding:
|
||||
"""Fake binding for testing state."""
|
||||
def __init__(
|
||||
self,
|
||||
binding_id: str = 'binding_001',
|
||||
state_policy: StatePolicy | None = None,
|
||||
scope_type: str = 'agent',
|
||||
scope_id: str = 'agent_001',
|
||||
):
|
||||
self.binding_id = binding_id
|
||||
self.scope = BindingScope(scope_type=scope_type, scope_id=scope_id)
|
||||
self.state_policy = state_policy or StatePolicy()
|
||||
|
||||
|
||||
class TestStateScopeHelpers:
|
||||
"""Tests for shared state scope helpers."""
|
||||
|
||||
def test_valid_state_scopes(self):
|
||||
assert VALID_STATE_SCOPES == ('conversation', 'actor', 'subject', 'runner')
|
||||
|
||||
def test_state_key_aliases(self):
|
||||
assert STATE_KEY_ALIASES == {'conversation_id': 'external.conversation_id'}
|
||||
assert normalize_state_key('conversation_id') == 'external.conversation_id'
|
||||
assert normalize_state_key('external.session_id') == 'external.session_id'
|
||||
|
||||
def test_binding_identity_uses_binding_id_first(self):
|
||||
binding = FakeBinding(binding_id='binding_a')
|
||||
assert get_binding_identity(binding) == 'binding_a'
|
||||
|
||||
def test_binding_identity_falls_back_to_scope(self):
|
||||
binding = FakeBinding(binding_id='', scope_type='workspace', scope_id='ws_001')
|
||||
assert get_binding_identity(binding) == 'workspace:ws_001'
|
||||
|
||||
def test_scope_key_building(self):
|
||||
descriptor = make_descriptor()
|
||||
binding = FakeBinding(binding_id='binding_a')
|
||||
event = FakeEventEnvelope(
|
||||
conversation_id='conv_001',
|
||||
actor=FakeActorContext(actor_id='user_001'),
|
||||
subject=FakeSubjectContext(subject_id='msg_001'),
|
||||
thread_id='thread_001',
|
||||
)
|
||||
|
||||
keys = {
|
||||
scope: build_state_scope_key(scope, event, binding, descriptor)
|
||||
for scope in VALID_STATE_SCOPES
|
||||
}
|
||||
|
||||
assert keys['conversation'].startswith('conversation:v2:')
|
||||
assert keys['actor'].startswith('actor:v2:')
|
||||
assert keys['subject'].startswith('subject:v2:')
|
||||
assert keys['runner'].startswith('runner:v2:')
|
||||
assert len(set(keys.values())) == len(keys)
|
||||
|
||||
def test_scope_key_missing_identity_returns_none(self):
|
||||
descriptor = make_descriptor()
|
||||
binding = FakeBinding()
|
||||
event = FakeEventEnvelope(conversation_id=None, actor=None, subject=None)
|
||||
|
||||
assert build_state_scope_key('conversation', event, binding, descriptor) is None
|
||||
assert build_state_scope_key('subject', event, binding, descriptor) is None
|
||||
assert build_state_scope_key('runner', event, binding, descriptor) is not None
|
||||
|
||||
def test_build_state_context(self):
|
||||
descriptor = make_descriptor()
|
||||
binding = FakeBinding(binding_id='binding_a')
|
||||
event = FakeEventEnvelope(
|
||||
conversation_id='conv_001',
|
||||
actor=FakeActorContext(actor_id='user_001'),
|
||||
subject=FakeSubjectContext(subject_id='msg_001'),
|
||||
)
|
||||
|
||||
context = build_state_context(event, binding, descriptor)
|
||||
|
||||
assert context['binding_identity'] == 'binding_a'
|
||||
assert context['conversation_id'] == 'conv_001'
|
||||
assert context['actor_id'] == 'user_001'
|
||||
assert set(context['scope_keys']) == {'conversation', 'actor', 'subject', 'runner'}
|
||||
|
||||
|
||||
class TestPersistentStateStore:
|
||||
"""Tests for persistent database-backed state store."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create a temporary async SQLite database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
os.unlink(db_path)
|
||||
|
||||
@pytest.fixture
|
||||
async def persistent_store(self, db_engine):
|
||||
"""Create a persistent state store for testing."""
|
||||
store = PersistentStateStore(db_engine)
|
||||
yield store
|
||||
await store.clear_all()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_snapshot_empty(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
|
||||
assert snapshot['conversation'] == {'external.conversation_id': 'conv_001'}
|
||||
assert snapshot['actor'] == {}
|
||||
assert snapshot['subject'] == {}
|
||||
assert snapshot['runner'] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_set_and_get(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'test_key', {'nested': 'value'}, None
|
||||
)
|
||||
assert success is True
|
||||
assert error is None
|
||||
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_first_state_set_uses_upsert(self, persistent_store):
|
||||
scope_key = 'conversation:runner:binding:conv_concurrent'
|
||||
|
||||
async def set_value(value: int):
|
||||
return await persistent_store.state_set(
|
||||
scope_key=scope_key,
|
||||
state_key='external.concurrent',
|
||||
value={'value': value},
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
binding_identity='binding_001',
|
||||
scope='conversation',
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*(set_value(value) for value in range(8)))
|
||||
|
||||
assert all(success is True and error is None for success, error in results)
|
||||
stored = await persistent_store.state_get(scope_key, 'external.concurrent')
|
||||
assert stored in [{'value': value} for value in range(8)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_api_methods_normalize_public_key_aliases(self, persistent_store):
|
||||
scope_key = 'conversation:runner:binding:conv_001'
|
||||
|
||||
success, error = await persistent_store.state_set(
|
||||
scope_key=scope_key,
|
||||
state_key='conversation_id',
|
||||
value='conv_001',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
binding_identity='binding_001',
|
||||
scope='conversation',
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert await persistent_store.state_get(scope_key, 'external.conversation_id') == 'conv_001'
|
||||
assert await persistent_store.state_get(scope_key, 'conversation_id') == 'conv_001'
|
||||
|
||||
keys, _ = await persistent_store.state_list(scope_key, prefix='conversation_id')
|
||||
assert keys == ['external.conversation_id']
|
||||
|
||||
assert await persistent_store.state_delete(scope_key, 'conversation_id') is True
|
||||
assert await persistent_store.state_get(scope_key, 'external.conversation_id') is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_binding_isolation(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding_a = FakeBinding(binding_id='binding_a')
|
||||
binding_b = FakeBinding(binding_id='binding_b')
|
||||
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding_a, descriptor, 'conversation', 'key', 'value_a', None
|
||||
)
|
||||
|
||||
snapshot_b = await persistent_store.build_snapshot_from_event(event, binding_b, descriptor)
|
||||
assert snapshot_b['conversation'] == {'external.conversation_id': 'conv_001'}
|
||||
|
||||
snapshot_a = await persistent_store.build_snapshot_from_event(event, binding_a, descriptor)
|
||||
assert snapshot_a['conversation']['key'] == 'value_a'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_policy_disable_state(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding(state_policy=StatePolicy(enable_state=False))
|
||||
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot == {'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}}
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value', None
|
||||
)
|
||||
assert success is False
|
||||
assert 'disabled' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_policy_scope_restriction(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(
|
||||
conversation_id='conv_001',
|
||||
actor=FakeActorContext(actor_id='user_001'),
|
||||
)
|
||||
binding = FakeBinding(state_policy=StatePolicy(state_scopes=['conversation']))
|
||||
|
||||
success_conv, _ = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value_conv', None
|
||||
)
|
||||
assert success_conv is True
|
||||
|
||||
success_actor, error_actor = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'actor', 'key', 'value_actor', None
|
||||
)
|
||||
assert success_actor is False
|
||||
assert 'not enabled' in error_actor.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_json_size_limit(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
large_value = 'x' * (300 * 1024)
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', large_value, None
|
||||
)
|
||||
assert success is False
|
||||
assert 'exceeds limit' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_not_json_serializable(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
success, error = await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', {'key': {1, 2, 3}}, None
|
||||
)
|
||||
assert success is False
|
||||
assert 'json' in error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_list(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'external.id', '123', None
|
||||
)
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'external.name', 'test', None
|
||||
)
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'memory.key', 'value', None
|
||||
)
|
||||
|
||||
scope_key = build_state_scope_key('conversation', event, binding, descriptor)
|
||||
|
||||
keys, has_more = await persistent_store.state_list(scope_key)
|
||||
assert len(keys) == 3
|
||||
assert has_more is False
|
||||
|
||||
keys_ext, _ = await persistent_store.state_list(scope_key, prefix='external.')
|
||||
assert len(keys_ext) == 2
|
||||
assert 'external.id' in keys_ext
|
||||
assert 'external.name' in keys_ext
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_delete(self, persistent_store):
|
||||
descriptor = make_descriptor()
|
||||
event = FakeEventEnvelope(conversation_id='conv_001')
|
||||
binding = FakeBinding()
|
||||
|
||||
await persistent_store.apply_update_from_event(
|
||||
event, binding, descriptor, 'conversation', 'key', 'value', None
|
||||
)
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot['conversation']['key'] == 'value'
|
||||
|
||||
scope_key = build_state_scope_key('conversation', event, binding, descriptor)
|
||||
deleted = await persistent_store.state_delete(scope_key, 'key')
|
||||
assert deleted is True
|
||||
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert 'key' not in snapshot['conversation']
|
||||
|
||||
deleted_again = await persistent_store.state_delete(scope_key, 'key')
|
||||
assert deleted_again is False
|
||||
Reference in New Issue
Block a user