mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-20 20:44:21 +00:00
feat(agent-runner): add plugin runner host integration
This commit is contained in:
@@ -122,11 +122,9 @@ class FakeApp:
|
||||
return cmd_mgr
|
||||
|
||||
def _create_mock_skill_mgr(self):
|
||||
"""Mock SkillManager that returns no skill index addition by default."""
|
||||
"""Mock SkillManager with no loaded skills by default."""
|
||||
skill_mgr = Mock()
|
||||
skill_mgr.skills = {}
|
||||
skill_mgr.build_skill_aware_prompt_addition = Mock(return_value='')
|
||||
skill_mgr.get_skill_index = Mock(return_value=[])
|
||||
return skill_mgr
|
||||
|
||||
def _create_mock_pipeline_service(self):
|
||||
|
||||
@@ -18,14 +18,7 @@
|
||||
- **测试方式**: 需要 mock HTTP 响应或使用 fake LLM server
|
||||
- **状态**: 后续可补充 mock HTTP 测试
|
||||
|
||||
### 3. Agent Runner (`provider/runners/`)
|
||||
- **路径**: `src/langbot/pkg/provider/runners/`
|
||||
- **模块**: cozeapi, difysvapi, n8nsvapi, langflowapi, dashscopeapi, localagent, tboxapi
|
||||
- **排除原因**: 需要真实 Agent 平台(Coze、Dify、n8n 等)的 API 连接
|
||||
- **测试方式**: 需要 mock Agent 平台响应
|
||||
- **状态**: 后续可补充 mock 测试
|
||||
|
||||
### 4. 向量数据库 (`vector/vdbs/`)
|
||||
### 3. 向量数据库 (`vector/vdbs/`)
|
||||
- **路径**: `src/langbot/pkg/vector/vdbs/`
|
||||
- **模块**: chroma, milvus, pgvector, qdrant, seekdb
|
||||
- **排除原因**: 需要真实向量数据库实例运行
|
||||
@@ -42,7 +35,7 @@
|
||||
# 排除外部适配器后计算覆盖率
|
||||
pytest tests/unit_tests/ --cov=langbot.pkg \
|
||||
--cov-fail-under=0 \
|
||||
-o "cov_exclude_patterns=platform/sources/*,provider/modelmgr/requesters/*,provider/runners/*,vector/vdbs/*"
|
||||
-o "cov_exclude_patterns=platform/sources/*,provider/modelmgr/requesters/*,vector/vdbs/*"
|
||||
```
|
||||
|
||||
### 当前覆盖率(排除后)
|
||||
@@ -77,15 +70,11 @@ pytest tests/unit_tests/ --cov=langbot.pkg \
|
||||
- 使用 `httpx` mock 测试 API 响应解析
|
||||
- 测试重试逻辑、错误处理
|
||||
|
||||
2. **`provider/runners/`** (优先级:中)
|
||||
- Mock Agent 平台响应
|
||||
- 测试 session 管理、错误处理
|
||||
|
||||
3. **`platform/sources/`** (优先级:低)
|
||||
2. **`platform/sources/`** (优先级:低)
|
||||
- Mock 平台 webhook 事件
|
||||
- 测试消息解析、事件处理
|
||||
|
||||
4. **`vector/vdbs/`** (优先级:低)
|
||||
3. **`vector/vdbs/`** (优先级:低)
|
||||
- Mock 向量数据库操作
|
||||
- 测试 CRUD、查询逻辑
|
||||
|
||||
@@ -176,4 +165,4 @@ tests/unit_tests/
|
||||
| `core` | **28%** | 1289 | 🔄 需补充 app 启动 |
|
||||
| `persistence` | **24%** | 1099 | 🔄 需补充 mgr |
|
||||
|
||||
外部适配器测试需要 mock 环境或集成测试,不属于纯单元测试范畴。
|
||||
外部适配器测试需要 mock 环境或集成测试,不属于纯单元测试范畴。
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
"""Tests for agent runner subsystem."""
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Shared test fixtures for agent runner tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
|
||||
def make_resources(
|
||||
models: list[dict] | None = None,
|
||||
tools: list[dict] | None = None,
|
||||
knowledge_bases: list[dict] | None = None,
|
||||
skills: list[dict] | None = None,
|
||||
storage: dict | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create a minimal AgentResources dict for testing.
|
||||
|
||||
Args:
|
||||
models: List of model dicts with 'model_id' key
|
||||
tools: List of tool dicts with 'tool_name' key
|
||||
knowledge_bases: List of KB dicts with 'kb_id' key
|
||||
skills: List of skill dicts with 'skill_name' key
|
||||
storage: Storage permissions dict
|
||||
Returns:
|
||||
AgentResources dict with all required fields
|
||||
"""
|
||||
return {
|
||||
'models': models or [],
|
||||
'tools': tools or [],
|
||||
'knowledge_bases': knowledge_bases or [],
|
||||
'skills': skills or [],
|
||||
'storage': storage or {'plugin_storage': False, 'workspace_storage': False},
|
||||
'platform_capabilities': {},
|
||||
}
|
||||
|
||||
|
||||
def make_session(
|
||||
run_id: str = 'test-run-id',
|
||||
runner_id: str = 'plugin:test/test-runner/default',
|
||||
query_id: int | None = 1,
|
||||
plugin_identity: str = 'test/test-runner',
|
||||
resources: dict | None = None,
|
||||
conversation_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
available_apis: dict[str, bool] | None = None,
|
||||
state_policy: dict[str, typing.Any] | None = None,
|
||||
state_context: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create a minimal AgentRunSession dict for testing.
|
||||
|
||||
Args:
|
||||
run_id: Unique run identifier
|
||||
runner_id: Runner descriptor ID
|
||||
query_id: Host entry query ID
|
||||
plugin_identity: Plugin identifier (author/name)
|
||||
resources: AgentResources dict (uses make_resources() default if None)
|
||||
|
||||
Returns:
|
||||
AgentRunSession dict with run-scoped authorization snapshot
|
||||
"""
|
||||
import time
|
||||
now = int(time.time())
|
||||
res = resources if resources is not None else make_resources()
|
||||
apis = available_apis if available_apis is not None else {}
|
||||
policy = (
|
||||
state_policy
|
||||
if state_policy is not None
|
||||
else {'enable_state': True, 'state_scopes': ['conversation', 'actor']}
|
||||
)
|
||||
context = state_context if state_context is not None else {}
|
||||
|
||||
authorized_ids: dict[str, set[str]] = {
|
||||
'model': {m.get('model_id') for m in res.get('models', [])},
|
||||
'tool': {t.get('tool_name') for t in res.get('tools', [])},
|
||||
'knowledge_base': {kb.get('kb_id') for kb in res.get('knowledge_bases', [])},
|
||||
'skill': {s.get('skill_name') for s in res.get('skills', [])},
|
||||
}
|
||||
authorized_operations: dict[str, dict[str, set[str]]] = {
|
||||
'model': {
|
||||
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank'])
|
||||
for m in res.get('models', [])
|
||||
if m.get('model_id')
|
||||
},
|
||||
'tool': {
|
||||
t.get('tool_name'): set(t.get('operations') or ['detail', 'call'])
|
||||
for t in res.get('tools', [])
|
||||
if t.get('tool_name')
|
||||
},
|
||||
'knowledge_base': {
|
||||
kb.get('kb_id'): set(kb.get('operations') or ['list', 'retrieve'])
|
||||
for kb in res.get('knowledge_bases', [])
|
||||
if kb.get('kb_id')
|
||||
},
|
||||
'skill': {
|
||||
s.get('skill_name'): set(s.get('operations') or ['activate'])
|
||||
for s in res.get('skills', [])
|
||||
if s.get('skill_name')
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
'run_id': run_id,
|
||||
'runner_id': runner_id,
|
||||
'query_id': query_id,
|
||||
'plugin_identity': plugin_identity,
|
||||
'authorization': {
|
||||
'resources': res,
|
||||
'available_apis': apis,
|
||||
'conversation_id': conversation_id,
|
||||
'bot_id': bot_id,
|
||||
'workspace_id': workspace_id,
|
||||
'thread_id': thread_id,
|
||||
'state_policy': policy,
|
||||
'state_context': context,
|
||||
'authorized_ids': authorized_ids,
|
||||
'authorized_operations': authorized_operations,
|
||||
},
|
||||
'status': {
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
"""Tests for ChatMessageHandler behavior with AgentRunOrchestrator.
|
||||
|
||||
Tests focus on:
|
||||
- Streaming mode behavior (single resp_message_id, pop/append pattern)
|
||||
- Non-streaming mode behavior (no pop)
|
||||
- Orchestrator invocation
|
||||
- Error handling for RunnerNotFoundError, RunnerExecutionError
|
||||
|
||||
Avoids circular imports by using proper import structure.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from langbot.pkg.agent.runner.errors import (
|
||||
RunnerNotFoundError,
|
||||
RunnerExecutionError,
|
||||
RunnerNotAuthorizedError,
|
||||
)
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
# Define mock classes in dependency order (no forward references needed)
|
||||
|
||||
class MockLauncherType:
    """Launcher-type stand-in that only exposes a ``value`` of ``'person'``."""

    value = 'person'
|
||||
|
||||
|
||||
class MockConversation:
    """Conversation stub with a fixed uuid and a per-instance message list."""

    def __init__(self):
        # A fresh list per instance so tests never share message history.
        self.messages = []
        self.uuid = 'conv-uuid'
|
||||
|
||||
|
||||
class MockMessage:
    """User message stub: role ``'user'`` with the content ``'Hello'``."""

    role, content = 'user', 'Hello'
|
||||
|
||||
|
||||
class MockAdapter:
    """Platform adapter stub whose streaming support is driven by ``is_stream``."""

    # Tests flip this flag (per instance) to switch between streaming and
    # non-streaming behavior.
    is_stream = False

    async def is_stream_output_supported(self):
        """Report the current value of ``is_stream``."""
        supported = self.is_stream
        return supported

    async def create_message_card(self, resp_message_id, message_event):
        """Accept the call and do nothing; always returns ``None``."""
        return None
|
||||
|
||||
|
||||
class MockSession:
    """Session stub: class-level launcher info plus a per-instance conversation."""

    launcher_id = 'user123'
    launcher_type = MockLauncherType()

    def __init__(self):
        # Each session starts with its own MockConversation.
        conversation = MockConversation()
        self.using_conversation = conversation
|
||||
|
||||
|
||||
class MockQuery:
    """Query stub carrying the attributes set up for the chat-handler tests."""

    def __init__(self):
        # Identity of the request and its sender.
        self.query_id = 1
        self.launcher_type = MockLauncherType()
        self.launcher_id = 'user123'
        self.sender_id = 'user123'
        self.bot_uuid = 'bot-uuid'
        self.pipeline_uuid = 'pipeline-uuid'

        # Pipeline configuration: runner selection and failure-output policy.
        ai_section = {
            'runner': {
                'id': 'plugin:langbot/local-agent/default',
            },
            'runner_config': {},
        }
        output_section = {
            'misc': {
                'exception-handling': 'show-hint',
                'failure-hint': 'Request failed.',
            },
        }
        self.pipeline_config = {'ai': ai_section, 'output': output_section}

        # Conversation state and response accumulators.
        self.variables = {}
        self.session = MockSession()
        self.user_message = MockMessage()
        self.messages = []
        self.resp_messages = []
        self.resp_message_chain = None

        # Platform-facing collaborators.
        self.adapter = MockAdapter()
        self.message_event = MagicMock()
        self.message_chain = MagicMock()
|
||||
|
||||
|
||||
class MockMessageChunk:
    """Assistant message-chunk stub.

    Args:
        content: Chunk text, also returned by ``readable_str()``.
        resp_message_id: Optional response message ID (``None`` by default).
    """

    def __init__(self, content, resp_message_id=None):
        self.content = content
        self.resp_message_id = resp_message_id
        self.role = 'assistant'
        self.is_final = False
        self.tool_calls = []

    def readable_str(self):
        """Return the chunk content unchanged."""
        text = self.content
        return text
|
||||
|
||||
|
||||
class MockEventContext:
    """Event-context stub returned by the mocked ``emit_event``.

    Args:
        prevented: Value reported by ``is_prevented_default()``.
        reply_message_chain: Stored on ``event.reply_message_chain``.
        user_message_alter: Stored on ``event.user_message_alter``.
    """

    def __init__(self, prevented=False, reply_message_chain=None, user_message_alter=None):
        event = MagicMock()
        event.reply_message_chain = reply_message_chain
        event.user_message_alter = user_message_alter
        self.event = event
        self._prevented = prevented

    def is_prevented_default(self):
        """Return the ``prevented`` flag given at construction."""
        flag = self._prevented
        return flag
|
||||
|
||||
|
||||
class MockAgentRunOrchestrator:
    """Orchestrator stub that replays canned chunks or raises a canned error.

    Args:
        chunks: Items yielded, in order, by ``run_from_query``.
        error: If truthy, raised by ``run_from_query`` before any chunk.
    """

    def __init__(self, chunks=None, error=None):
        self._error = error
        self._chunks = chunks if chunks else []

    async def run_from_query(self, query):
        """Async generator: raise the configured error, else yield each chunk."""
        if self._error:
            raise self._error
        pending = iter(self._chunks)
        for item in pending:
            yield item

    async def try_claim_steering_from_query(self, query):
        """Never claims steering; always returns ``False``."""
        return False

    def resolve_runner_id_for_telemetry(self, query):
        """Return the fixed local-agent runner ID regardless of ``query``."""
        return 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class MockApplication:
    """Application stub wiring up the collaborators mocked for these tests.

    Args:
        orchestrator: Object exposed as ``agent_run_orchestrator``; a default
            ``MockAgentRunOrchestrator()`` is used when omitted (or falsy).
    """

    def __init__(self, orchestrator=None):
        self.agent_run_orchestrator = orchestrator if orchestrator else MockAgentRunOrchestrator()

        # Logger with explicit mocks for each level.
        logger = MagicMock()
        for level in ('info', 'debug', 'warning', 'error'):
            setattr(logger, level, MagicMock())
        self.logger = logger

        # Plugin connector: emit_event resolves to a default MockEventContext.
        plugin_connector = MagicMock()
        plugin_connector.emit_event = AsyncMock(return_value=MockEventContext())
        self.plugin_connector = plugin_connector

        # Telemetry
        telemetry = MagicMock()
        telemetry.start_send_task = AsyncMock()
        self.telemetry = telemetry

        # Survey
        survey = MagicMock()
        survey.trigger_event = AsyncMock()
        self.survey = survey

        # Model manager: lookups resolve to None.
        model_mgr = MagicMock()
        model_mgr.get_model_by_uuid = AsyncMock(return_value=None)
        self.model_mgr = model_mgr

        # Session manager
        sess_mgr = MagicMock()
        sess_mgr.get_conversation = AsyncMock()
        self.sess_mgr = sess_mgr
|
||||
|
||||
|
||||
class TestStreamingBehavior:
    """Tests for streaming mode behavior.

    These tests replay the pop/append bookkeeping on a local list; they do
    not invoke the chat handler itself.
    """

    def test_single_resp_message_id_for_streaming(self):
        """Streaming mode should use single resp_message_id for entire response."""
        # The ID is generated once, outside the chunk loop.
        shared_id = str(uuid.uuid4())
        resp_messages = []

        for text in ('Hello', ' World', '!'):
            chunk = MockMessageChunk(text, resp_message_id=shared_id)
            # Streaming replaces the previous chunk instead of accumulating.
            if len(resp_messages) > 0:
                resp_messages.pop()
            resp_messages.append(chunk)

        # Only the last chunk survives, and it carries the shared ID.
        assert len(resp_messages) == 1
        assert resp_messages[0].resp_message_id == shared_id

    def test_pop_before_append_in_streaming(self):
        """Streaming mode should pop old chunk before appending new."""
        shared_id = str(uuid.uuid4())

        # First chunk: nothing to pop yet.
        resp_messages = [MockMessageChunk('Hello', resp_message_id=shared_id)]
        assert len(resp_messages) == 1

        # Second chunk: drop the previous one, then append.
        if resp_messages:
            resp_messages.pop()
        resp_messages.append(MockMessageChunk('Hello World', resp_message_id=shared_id))

        assert len(resp_messages) == 1
        assert resp_messages[0].content == 'Hello World'

    def test_non_streaming_no_pop(self):
        """Non-streaming mode should NOT pop previous responses."""
        resp_messages = []

        # Every message is kept, so the list grows by one per append.
        for expected_len, text in enumerate(('Response 1', 'Response 2'), start=1):
            resp_messages.append(MockMessageChunk(text))
            assert len(resp_messages) == expected_len
|
||||
|
||||
|
||||
class TestConfigMigrationInChatHandler:
    """Tests for ConfigMigration usage in chat handler context."""

    EXPECTED_RUNNER_ID = 'plugin:langbot/local-agent/default'

    def test_resolve_runner_id_from_pipeline_config(self):
        """Chat handler should use ConfigMigration to resolve runner ID."""
        runner_section = {'id': 'plugin:langbot/local-agent/default'}
        pipeline_config = {'ai': {'runner': runner_section}}

        resolved = ConfigMigration.resolve_runner_id(pipeline_config)

        assert resolved == self.EXPECTED_RUNNER_ID

    def test_resolve_runner_id_from_old_format(self):
        """ConfigMigration resolves old runner aliases for compatibility."""
        # Old format: the alias lives under ai.runner.runner.
        runner_section = {'runner': 'local-agent'}
        pipeline_config = {'ai': {'runner': runner_section}}

        resolved = ConfigMigration.resolve_runner_id(pipeline_config)

        assert resolved == self.EXPECTED_RUNNER_ID
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Tests for orchestrator error handling."""

    def test_runner_not_found_error_properties(self):
        """RunnerNotFoundError should have runner_id property."""
        missing_id = 'plugin:notexist/unknown/default'
        error = RunnerNotFoundError(missing_id)

        assert error.runner_id == 'plugin:notexist/unknown/default'
        assert 'not found' in str(error)

    def test_runner_execution_error_retryable(self):
        """RunnerExecutionError should have retryable property."""
        runner_id = 'plugin:langbot/local-agent/default'
        error = RunnerExecutionError(runner_id, 'Upstream timeout', retryable=True)

        assert error.runner_id == 'plugin:langbot/local-agent/default'
        assert error.retryable is True
        assert 'timeout' in str(error)

    def test_runner_execution_error_not_retryable(self):
        """RunnerExecutionError can be non-retryable."""
        runner_id = 'plugin:langbot/local-agent/default'
        error = RunnerExecutionError(runner_id, 'Configuration error', retryable=False)

        assert error.retryable is False

    def test_runner_not_authorized_error_properties(self):
        """RunnerNotAuthorizedError should have bound_plugins property."""
        runner_id = 'plugin:langbot/local-agent/default'
        bound = ['langbot/dify-agent']
        error = RunnerNotAuthorizedError(runner_id, bound)

        assert error.runner_id == 'plugin:langbot/local-agent/default'
        assert error.bound_plugins == ['langbot/dify-agent']
|
||||
|
||||
|
||||
class TestChatHandlerImports:
    """Test that chat handler can be imported without circular import."""

    def test_import_chat_handler_module(self):
        """Import chat handler module should work."""
        # Importing inside the test is the point: a circular dependency would
        # surface here as an ImportError.
        from langbot.pkg.pipeline.process.handlers import chat as chat_module

        handler_cls = chat_module.ChatMessageHandler
        assert handler_cls is not None

    def test_chat_handler_class_exists(self):
        """ChatMessageHandler class should be defined."""
        from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler

        class_name = ChatMessageHandler.__name__
        assert class_name == 'ChatMessageHandler'

    def test_chat_handler_has_handle_method(self):
        """ChatMessageHandler should have async generator handle method."""
        import inspect

        from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler

        assert hasattr(ChatMessageHandler, 'handle')
        # handle is expected to be declared with ``async def`` + ``yield``.
        handle_fn = ChatMessageHandler.handle
        assert inspect.isasyncgenfunction(handle_fn)
|
||||
|
||||
|
||||
class TestChatHandlerAsyncBehavior:
    """Real async tests for ChatMessageHandler.handle() with mocked orchestrator.

    Every test follows the same recipe: build a ``MockApplication`` and a
    ``MockQuery``, drive ``handler.handle(query)`` to completion through
    ``_run_handler``, then assert on the yielded results and on the state
    left on the query.
    """

    @staticmethod
    def _make_app(orchestrator=None, event_context=None):
        """Build a MockApplication whose ``emit_event`` resolves to ``event_context``.

        Args:
            orchestrator: Passed through to ``MockApplication``.
            event_context: Context returned by ``emit_event``; defaults to a
                context that does not prevent the default behavior.
        """
        if event_context is None:
            event_context = MockEventContext(prevented=False)
        mock_ap = MockApplication(orchestrator=orchestrator)
        mock_ap.plugin_connector.emit_event = AsyncMock(return_value=event_context)
        return mock_ap

    @staticmethod
    async def _run_handler(mock_ap, query):
        """Run ``ChatMessageHandler(mock_ap).handle(query)`` and return all yielded results.

        The event classes used by the chat module and ``StageProcessResult``
        are patched to bypass pydantic validation. Each patched result is a
        ``MagicMock`` exposing the ``result_type`` (``CONTINUE`` when the
        handler passes none) and ``user_notice`` keyword arguments it was
        created with.
        """
        # Imported lazily, as in the rest of this module, so that importing
        # the test file does not itself import the handler.
        from langbot.pkg.pipeline.process.handlers.chat import ChatMessageHandler
        from langbot.pkg.pipeline import entities

        handler = ChatMessageHandler(mock_ap)

        mock_event = MagicMock()
        mock_event.return_value = MagicMock()

        def make_result(*args, **kwargs):
            return MagicMock(
                result_type=kwargs.get('result_type', entities.ResultType.CONTINUE),
                user_notice=kwargs.get('user_notice'),
            )

        with patch('langbot.pkg.pipeline.process.handlers.chat.events') as mock_events_module, \
                patch('langbot.pkg.pipeline.entities.StageProcessResult', side_effect=make_result):
            mock_events_module.PersonNormalMessageReceived = mock_event
            mock_events_module.GroupNormalMessageReceived = mock_event

            results = []
            async for result in handler.handle(query):
                results.append(result)

        return results

    async def _assert_interrupt_with_notice(self, error, expected_fragment):
        """Run the handler with an orchestrator raising ``error``.

        Asserts that exactly one result is yielded, that it is an INTERRUPT,
        and that its ``user_notice`` contains ``expected_fragment``.
        """
        from langbot.pkg.pipeline import entities

        mock_ap = self._make_app(orchestrator=MockAgentRunOrchestrator(error=error))
        results = await self._run_handler(mock_ap, MockQuery())

        assert len(results) == 1
        assert results[0].result_type == entities.ResultType.INTERRUPT
        assert expected_fragment in results[0].user_notice

    @pytest.mark.asyncio
    async def test_streaming_single_resp_message_id(self):
        """Streaming mode: all chunks should have same resp_message_id."""
        chunks = [
            MockMessageChunk('Hello'),
            MockMessageChunk('Hello World'),
            MockMessageChunk('Hello World!'),
        ]
        mock_ap = self._make_app(orchestrator=MockAgentRunOrchestrator(chunks=chunks))

        query = MockQuery()
        query.adapter.is_stream = True  # Enable streaming mode

        await self._run_handler(mock_ap, query)

        # Verify single resp_message_id
        resp_ids = [msg.resp_message_id for msg in query.resp_messages if hasattr(msg, 'resp_message_id')]
        assert len(set(resp_ids)) == 1  # All same ID

        # Verify pop/append pattern: only last chunk remains
        assert len(query.resp_messages) == 1
        assert query.resp_messages[0].content == 'Hello World!'

    @pytest.mark.asyncio
    async def test_non_streaming_no_pop(self):
        """Non-streaming mode: all chunks should remain."""
        chunks = [
            MockMessageChunk('Response 1'),
            MockMessageChunk('Response 2'),
        ]
        mock_ap = self._make_app(orchestrator=MockAgentRunOrchestrator(chunks=chunks))

        query = MockQuery()
        query.adapter.is_stream = False  # Disable streaming mode

        await self._run_handler(mock_ap, query)

        # No pop: all chunks should remain
        assert len(query.resp_messages) == 2
        assert query.resp_messages[0].content == 'Response 1'
        assert query.resp_messages[1].content == 'Response 2'

    @pytest.mark.asyncio
    async def test_agent_turn_recreates_conversation_if_tool_resets_it(self):
        """Agent turn bookkeeping should tolerate CREATE_NEW_CONVERSATION during runner execution."""
        from langbot.pkg.pipeline import entities

        response = MockMessageChunk('Tool response')
        new_conversation = MockConversation()

        class ResetConversationOrchestrator(MockAgentRunOrchestrator):
            async def run_from_query(self, query):
                # Simulate the runner dropping the active conversation mid-run.
                query.session.using_conversation = None
                yield response

        mock_ap = self._make_app(orchestrator=ResetConversationOrchestrator())
        mock_ap.sess_mgr.get_conversation = AsyncMock(return_value=new_conversation)

        query = MockQuery()
        query.adapter.is_stream = False

        results = await self._run_handler(mock_ap, query)

        assert len(results) == 1
        assert results[0].result_type == entities.ResultType.CONTINUE
        mock_ap.sess_mgr.get_conversation.assert_awaited_once()
        assert query.session.using_conversation is new_conversation
        assert new_conversation.messages == []

    @pytest.mark.asyncio
    async def test_runner_not_found_error(self):
        """Handler should catch RunnerNotFoundError and return INTERRUPT."""
        await self._assert_interrupt_with_notice(
            RunnerNotFoundError('plugin:notexist/unknown/default'),
            'not found',
        )

    @pytest.mark.asyncio
    async def test_runner_not_authorized_error(self):
        """Handler should catch RunnerNotAuthorizedError and return INTERRUPT."""
        await self._assert_interrupt_with_notice(
            RunnerNotAuthorizedError('plugin:langbot/local-agent/default', ['other/plugin']),
            'not authorized',
        )

    @pytest.mark.asyncio
    async def test_runner_execution_error_retryable(self):
        """Handler should catch retryable RunnerExecutionError."""
        await self._assert_interrupt_with_notice(
            RunnerExecutionError('plugin:langbot/local-agent/default', 'timeout', retryable=True),
            'temporarily unavailable',
        )

    @pytest.mark.asyncio
    async def test_prevented_default_with_reply(self):
        """When event prevented default with reply, use reply message."""
        from langbot.pkg.pipeline import entities

        # Mock reply message chain
        reply_chain = MockMessageChunk('Reply from plugin')
        mock_ap = self._make_app(
            event_context=MockEventContext(prevented=True, reply_message_chain=reply_chain),
        )

        query = MockQuery()

        results = await self._run_handler(mock_ap, query)

        # Should return CONTINUE with reply message
        assert len(results) == 1
        assert results[0].result_type == entities.ResultType.CONTINUE
        assert len(query.resp_messages) == 1
|
||||
@@ -0,0 +1,257 @@
|
||||
"""Tests for current AgentRunner config helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
class TestResolveRunnerId:
    """Tests for ConfigMigration.resolve_runner_id."""

    @staticmethod
    def _config_with_runner(runner_section):
        """Wrap ``runner_section`` as ``{'ai': {'runner': ...}}``."""
        return {'ai': {'runner': runner_section}}

    def test_resolve_current_runner_id(self):
        """The current format stores the full runner ID under ai.runner.id."""
        pipeline_config = self._config_with_runner({'id': 'plugin:langbot/local-agent/default'})

        resolved = ConfigMigration.resolve_runner_id(pipeline_config)

        assert resolved == 'plugin:langbot/local-agent/default'

    def test_resolves_old_runner_field(self):
        """The old ai.runner.runner alias maps to the plugin runner ID."""
        pipeline_config = self._config_with_runner({'runner': 'local-agent'})

        resolved = ConfigMigration.resolve_runner_id(pipeline_config)

        assert resolved == 'plugin:langbot/local-agent/default'

    def test_resolves_deerflow_and_weknora_legacy_runner_fields(self):
        """Legacy deerflow/weknora aliases map to their plugin runner IDs."""
        legacy_to_plugin = (
            ('deerflow-api', 'plugin:langbot/deerflow-agent/default'),
            ('weknora-api', 'plugin:langbot/weknora-agent/default'),
        )
        for legacy_name, expected_id in legacy_to_plugin:
            pipeline_config = self._config_with_runner({'runner': legacy_name})
            assert ConfigMigration.resolve_runner_id(pipeline_config) == expected_id

    def test_resolve_no_runner_config(self):
        """An empty pipeline config resolves to no runner at all."""
        resolved = ConfigMigration.resolve_runner_id({})
        assert resolved is None
|
||||
|
||||
|
||||
class TestResolveRunnerConfig:
    """Tests for ConfigMigration.resolve_runner_config."""

    LOCAL_AGENT_ID = 'plugin:langbot/local-agent/default'

    def test_resolve_current_config(self):
        """Current format: config is keyed by runner ID under ai.runner_config."""
        runner_settings = {'model': 'uuid-123', 'custom_option': 10}
        pipeline_config = {
            'ai': {
                'runner_config': {
                    'plugin:langbot/local-agent/default': runner_settings,
                },
            },
        }

        resolved = ConfigMigration.resolve_runner_config(pipeline_config, self.LOCAL_AGENT_ID)

        assert resolved == {'model': 'uuid-123', 'custom_option': 10}

    def test_reads_old_runner_block(self):
        """Old format: the ai.local-agent block is read and its model is expanded."""
        pipeline_config = {'ai': {'local-agent': {'model': 'uuid-123'}}}

        resolved = ConfigMigration.resolve_runner_config(pipeline_config, self.LOCAL_AGENT_ID)

        # The plain model string comes back as a primary/fallbacks mapping.
        assert resolved == {'model': {'primary': 'uuid-123', 'fallbacks': []}}

    def test_reads_deerflow_and_weknora_legacy_runner_blocks(self):
        """Legacy deerflow/weknora blocks are returned unchanged for their runners."""
        deerflow_block = {
            'api-base': 'http://127.0.0.1:2026',
            'assistant-id': 'lead_agent',
        }
        weknora_block = {
            'base-url': 'http://localhost:8080/api/v1',
            'app-type': 'agent',
        }
        pipeline_config = {
            'ai': {
                'deerflow-api': dict(deerflow_block),
                'weknora-api': dict(weknora_block),
            },
        }

        expectations = (
            ('plugin:langbot/deerflow-agent/default', deerflow_block),
            ('plugin:langbot/weknora-agent/default', weknora_block),
        )
        # Resolve both before asserting, mirroring a single shared config.
        resolved = [
            ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
            for runner_id, _ in expectations
        ]

        for actual, (_, expected) in zip(resolved, expectations):
            assert actual == expected

    def test_resolve_no_config(self):
        """With no configuration at all, an empty dict is returned."""
        resolved = ConfigMigration.resolve_runner_config({}, self.LOCAL_AGENT_ID)
        assert resolved == {}
|
||||
|
||||
|
||||
class TestGetExpireTime:
    """Exercises ConfigMigration.get_expire_time."""

    def test_get_expire_time_zero(self):
        """Expects ``0`` when ``ai.runner.expire-time`` is ``0``."""
        with_zero = {'ai': {'runner': {'expire-time': 0}}}

        assert ConfigMigration.get_expire_time(with_zero) == 0

    def test_get_expire_time_positive(self):
        """Expects the configured value when ``ai.runner.expire-time`` is ``3600``."""
        with_hour = {'ai': {'runner': {'expire-time': 3600}}}

        assert ConfigMigration.get_expire_time(with_hour) == 3600

    def test_get_expire_time_default(self):
        """Expects ``0`` when the pipeline config is empty."""
        assert ConfigMigration.get_expire_time({}) == 0
|
||||
|
||||
|
||||
class TestNormalizePipelineConfig:
    """Exercises ConfigMigration.migrate_pipeline_config."""

    def test_normalizes_current_containers(self):
        """Expects an empty ``ai`` section to gain empty ``runner`` and ``runner_config`` dicts."""
        result = ConfigMigration.migrate_pipeline_config({'ai': {}})

        assert result == {'ai': {'runner': {}, 'runner_config': {}}}

    def test_preserves_current_config(self):
        """Expects an ``ai.runner.id`` / ``ai.runner_config`` pair to keep its values."""
        runner_id = 'plugin:test/my-runner/default'
        source = {
            'ai': {
                'runner': {'id': runner_id},
                'runner_config': {runner_id: {'custom-option': 20}},
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner']['id'] == 'plugin:test/my-runner/default'
        assert result['ai']['runner_config']['plugin:test/my-runner/default']['custom-option'] == 20

    def test_migrates_old_runner_blocks(self):
        """Expects the ``local-agent`` block to move under ``runner_config`` with converted fields."""
        source = {
            'ai': {
                'runner': {'runner': 'local-agent'},
                'local-agent': {'model': 'old-model', 'knowledge-base': 'kb_1'},
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner']['id'] == 'plugin:langbot/local-agent/default'
        assert 'runner' not in result['ai']['runner']
        assert 'local-agent' not in result['ai']
        assert result['ai']['runner_config']['plugin:langbot/local-agent/default'] == {
            'model': {'primary': 'old-model', 'fallbacks': []},
            'knowledge-bases': ['kb_1'],
        }

    def test_migrates_deerflow_legacy_runner_block(self):
        """Expects the ``deerflow-api`` block to move under the deerflow-agent runner id."""
        source = {
            'ai': {
                'runner': {'runner': 'deerflow-api'},
                'deerflow-api': {
                    'api-base': 'http://127.0.0.1:2026',
                    'assistant-id': 'lead_agent',
                },
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner']['id'] == 'plugin:langbot/deerflow-agent/default'
        assert 'runner' not in result['ai']['runner']
        assert 'deerflow-api' not in result['ai']
        assert result['ai']['runner_config']['plugin:langbot/deerflow-agent/default'] == {
            'api-base': 'http://127.0.0.1:2026',
            'assistant-id': 'lead_agent',
        }

    def test_migrates_weknora_legacy_runner_block(self):
        """Expects the ``weknora-api`` block to move under the weknora-agent runner id."""
        source = {
            'ai': {
                'runner': {'runner': 'weknora-api'},
                'weknora-api': {
                    'base-url': 'http://localhost:8080/api/v1',
                    'app-type': 'agent',
                },
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner']['id'] == 'plugin:langbot/weknora-agent/default'
        assert 'runner' not in result['ai']['runner']
        assert 'weknora-api' not in result['ai']
        assert result['ai']['runner_config']['plugin:langbot/weknora-agent/default'] == {
            'base-url': 'http://localhost:8080/api/v1',
            'app-type': 'agent',
        }
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests for persisted AgentRunner config shape."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
|
||||
class TestMigratePipelineConfig:
    """Exercises ConfigMigration.migrate_pipeline_config on persisted config shapes."""

    def test_current_format_config_stays_unchanged(self):
        """Expects the runner id and a runner_config option to survive migration."""
        runner_id = 'plugin:langbot/local-agent/default'
        source = {
            'ai': {
                'runner': {'id': runner_id, 'expire-time': 0},
                'runner_config': {
                    runner_id: {
                        'model': {'primary': '', 'fallbacks': []},
                        'custom-option': 10,
                    },
                },
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner']['id'] == 'plugin:langbot/local-agent/default'
        assert result['ai']['runner_config']['plugin:langbot/local-agent/default']['custom-option'] == 10

    def test_old_runner_field_is_mapped(self):
        """Expects the ``runner`` field and ``local-agent`` block to be rewritten to the id-keyed shape."""
        source = {
            'ai': {
                'runner': {'runner': 'local-agent', 'expire-time': 3600},
                'local-agent': {'model': 'old-model'},
            },
        }

        result = ConfigMigration.migrate_pipeline_config(source)

        assert result['ai']['runner'] == {
            'expire-time': 3600,
            'id': 'plugin:langbot/local-agent/default',
        }
        assert result['ai']['runner_config']['plugin:langbot/local-agent/default'] == {
            'model': {'primary': 'old-model', 'fallbacks': []},
        }
        assert 'local-agent' not in result['ai']

    def test_empty_config_is_unchanged(self):
        """Expects an empty dict to come back as an empty dict."""
        assert ConfigMigration.migrate_pipeline_config({}) == {}

    def test_config_without_ai_section_is_unchanged(self):
        """Expects a config that has no ``ai`` key to come back with the same content."""
        result = ConfigMigration.migrate_pipeline_config({'trigger': {}})

        assert result == {'trigger': {}}
|
||||
|
||||
|
||||
class TestDefaultPipelineConfig:
    """Checks the shape of the bundled ``default-pipeline-config.json`` template."""

    def test_default_config_is_current_format(self):
        """Expects an empty runner id, an empty ``runner_config`` and no ``local-agent`` block."""
        from langbot.pkg.utils import paths as path_utils

        location = path_utils.get_resource_path('templates/default-pipeline-config.json')
        with open(location, 'r', encoding='utf-8') as handle:
            template = json.load(handle)

        assert 'ai' in template
        # Bound only after the membership check so a missing key fails as an assertion.
        ai_section = template['ai']
        assert 'runner' in ai_section
        assert 'id' in ai_section['runner']
        assert ai_section['runner']['id'] == ''
        assert 'runner_config' in ai_section
        assert ai_section['runner_config'] == {}
        assert 'local-agent' not in ai_section
|
||||
|
||||
|
||||
class TestResolveRunnerId:
    """Exercises ConfigMigration.resolve_runner_id on persisted config shapes."""

    def test_resolve_current_id(self):
        """Expects the value stored under ``ai.runner.id`` to be returned."""
        source = {'ai': {'runner': {'id': 'plugin:test/my-runner/default'}}}

        assert ConfigMigration.resolve_runner_id(source) == 'plugin:test/my-runner/default'

    def test_old_runner_field_is_mapped(self):
        """Expects ``ai.runner.runner = 'local-agent'`` to resolve to the local-agent plugin id."""
        source = {'ai': {'runner': {'runner': 'local-agent'}}}

        assert ConfigMigration.resolve_runner_id(source) == 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class TestResolveRunnerConfig:
    """Exercises ConfigMigration.resolve_runner_config on persisted config shapes."""

    def test_resolve_current_config(self):
        """Expects an option stored under ``ai.runner_config[<runner id>]`` to be readable."""
        runner_id = 'plugin:langbot/local-agent/default'
        source = {'ai': {'runner_config': {runner_id: {'custom-option': 20}}}}

        resolved = ConfigMigration.resolve_runner_config(source, runner_id)

        assert resolved['custom-option'] == 20

    def test_old_runner_block_is_read(self):
        """Expects an ``ai.local-agent`` block to be returned for the local-agent runner id."""
        source = {'ai': {'local-agent': {'custom-option': 20}}}

        resolved = ConfigMigration.resolve_runner_config(source, 'plugin:langbot/local-agent/default')

        assert resolved == {'custom-option': 20}
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for Query entry adapter params packaging."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
|
||||
|
||||
class TestBuildParams:
    """Exercises the filtering applied by QueryEntryAdapter.build_params."""

    @staticmethod
    def _query_with(variables):
        """Return an instance of a throwaway ``Query`` class whose ``variables`` attribute is *variables*."""
        return type('Query', (), {'variables': variables})()

    def test_params_empty_when_no_variables(self):
        """Expects an empty dict when ``variables`` is ``None``."""
        assert QueryEntryAdapter.build_params(self._query_with(None)) == {}

    def test_params_filters_underscore_prefix(self):
        """Expects keys starting with an underscore to be left out."""
        query = self._query_with(
            {
                '_internal_var': 'should_be_excluded',
                '_pipeline_bound_plugins': ['a/b'],
                '_monitoring_bot_name': 'Bot',
                'public_var': 'should_be_included',
            }
        )

        built = QueryEntryAdapter.build_params(query)

        for hidden in ('_internal_var', '_pipeline_bound_plugins', '_monitoring_bot_name'):
            assert hidden not in built
        assert built['public_var'] == 'should_be_included'

    def test_params_filters_sensitive_naming(self):
        """Expects keys that look like credentials to be left out, and neutral keys kept."""
        query = self._query_with(
            {
                'api_key': 'secret123',
                'API_KEY': 'secret456',
                'token': 'tok123',
                'secret': 'sec123',
                'password': 'pass123',
                'credential': 'cred123',
                'user_api_key': 'should_be_excluded',
                'user_secret_key': 'should_be_excluded',
                'my_token_value': 'should_be_excluded',
                'user_password_hash': 'should_be_excluded',
                'public_name': 'should_be_included',
                'safe_value': 'should_be_included',
            }
        )

        built = QueryEntryAdapter.build_params(query)

        excluded = (
            'api_key',
            'API_KEY',
            'token',
            'secret',
            'password',
            'credential',
            'user_api_key',
            'user_secret_key',
            'my_token_value',
            'user_password_hash',
        )
        for name in excluded:
            assert name not in built
        for name in ('public_name', 'safe_value'):
            assert name in built

    def test_params_keeps_common_public_vars(self):
        """Expects ordinary launcher/sender/session variables to pass through with their values."""
        public_vars = {
            'launcher_type': 'telegram',
            'launcher_id': 'group_123',
            'sender_id': 'user_001',
            'session_id': 'sess_abc',
            'msg_create_time': 1234567890,
            'group_name': 'Tech Group',
            'sender_name': 'John',
            'user_message_text': 'Hello world',
        }
        # Snapshot taken before the call; every value is an immutable str/int,
        # so the expectations cannot be altered through the input dict.
        expected = dict(public_vars)

        built = QueryEntryAdapter.build_params(self._query_with(public_vars))

        for name, value in expected.items():
            assert built[name] == value

    def test_params_filters_non_json_serializable(self):
        """Expects JSON-friendly values to be kept and a plain object instance to be dropped."""

        class CustomObject:
            pass

        query = self._query_with(
            {
                'string_value': 'hello',
                'int_value': 42,
                'float_value': 3.14,
                'bool_value': True,
                'null_value': None,
                'list_value': ['a', 'b', 'c'],
                'dict_value': {'nested': 'value'},
                'custom_object': CustomObject(),
            }
        )

        built = QueryEntryAdapter.build_params(query)

        kept = (
            'string_value',
            'int_value',
            'float_value',
            'bool_value',
            'null_value',
            'list_value',
            'dict_value',
        )
        for name in kept:
            assert name in built
        assert 'custom_object' not in built

    def test_params_filters_nested_non_serializable(self):
        """Expects containers holding a plain object instance to be dropped, clean nesting kept."""

        class CustomObject:
            pass

        query = self._query_with(
            {
                'nested_list_with_bad': ['a', CustomObject(), 'c'],
                'nested_dict_with_bad': {'good': 'value', 'bad': CustomObject()},
                'good_nested_list': ['a', ['b', 'c']],
                'good_nested_dict': {'outer': {'inner': 'value'}},
            }
        )

        built = QueryEntryAdapter.build_params(query)

        for name in ('nested_list_with_bad', 'nested_dict_with_bad'):
            assert name not in built
        for name in ('good_nested_list', 'good_nested_dict'):
            assert name in built

    def test_is_json_serializable_primitives_and_collections(self):
        """Expects ``True`` for None, str, int, list, dict and tuple samples."""
        accepted = (None, 'string', 42, ['a', 'b'], {'key': 'value'}, (1, 2, 3))
        for sample in accepted:
            assert QueryEntryAdapter.is_json_serializable(sample) is True

    def test_is_json_serializable_rejects_sets_and_objects(self):
        """Expects ``False`` for a plain object instance and for sets at any nesting level."""

        class CustomObject:
            pass

        rejected = (CustomObject(), {1, 2, 3}, [1, {2, 3}], {'key': {1, 2}})
        for sample in rejected:
            assert QueryEntryAdapter.is_json_serializable(sample) is False
|
||||
|
||||
|
||||
class TestBuildAdapterContext:
    """Exercises QueryEntryAdapter.build_adapter_context."""

    def test_adapter_context_does_not_push_prompt(self):
        """Expects only ``params`` and ``query_id`` in the context even though the query has a ``prompt``."""
        attributes = {
            'variables': {},
            'query_id': 123,
            'prompt': object(),
        }
        query = type('Query', (), attributes)()

        built = QueryEntryAdapter.build_adapter_context(query, binding=None)

        assert built == {'params': {}, 'query_id': 123}
|
||||
@@ -0,0 +1,361 @@
|
||||
"""Tests for ContextAccess.state determination in AgentRunContextBuilder.
|
||||
|
||||
Tests focus on:
|
||||
- Event-first mode: state=True when enable_state=True and state_scopes non-empty
|
||||
- Event-first mode: state=False when enable_state=False
|
||||
- Legacy Query mode: state=False (no persistent state API)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
|
||||
class MockApplication:
    """Stand-in application object exposing ``logger`` and ``persistence_mgr`` as MagicMocks."""

    def __init__(self):
        persistence = MagicMock()
        # Explicitly attached, mirroring how the attribute is looked up on the manager.
        persistence.get_db_engine = MagicMock()

        self.logger = MagicMock()
        self.persistence_mgr = persistence
|
||||
|
||||
|
||||
def make_descriptor(permissions: dict | None = None) -> AgentRunnerDescriptor:
    """Build the plugin runner descriptor used by the tests in this module.

    Args:
        permissions: Permissions mapping to put on the descriptor. When ``None``,
            a default granting history page/search, events get/page and plugin
            storage is used. An empty dict is passed through as-is.

    Returns:
        An ``AgentRunnerDescriptor`` with id ``plugin:test/runner/default``.
    """
    if permissions is None:
        permissions = {
            'history': ['page', 'search'],
            'events': ['get', 'page'],
            'storage': ['plugin'],
        }

    return AgentRunnerDescriptor(
        id='plugin:test/runner/default',
        source='plugin',
        label={'en_US': 'Test Runner'},
        plugin_author='test',
        plugin_name='runner',
        runner_name='default',
        permissions=permissions,
    )
|
||||
|
||||
|
||||
class TestContextAccessStateDetermination:
    """Calls AgentRunContextBuilder._build_context_access and checks ``available_apis['state']``."""

    @staticmethod
    def _make_event(event_id, conversation_id):
        """Build a ``message.received`` envelope that varies only in id and conversation."""
        return AgentEventEnvelope(
            event_id=event_id,
            event_type='message.received',
            event_time=1234567890,
            source='test',
            bot_id='bot_001',
            workspace_id='ws_001',
            conversation_id=conversation_id,
            thread_id=None,
            actor=ActorContext(actor_type='user', actor_id='user_001'),
            subject=None,
            input=AgentInput(text='hello', contents=[], attachments=[]),
            delivery=DeliveryContext(surface='test', supports_streaming=True),
        )

    @staticmethod
    def _make_binding(binding_id, enable_state, state_scopes, scope_type='agent', scope_id='conv_001'):
        """Build a binding for the test runner with the given state policy and scope."""
        return AgentBinding(
            binding_id=binding_id,
            runner_id='plugin:test/runner/default',
            scope=BindingScope(scope_type=scope_type, scope_id=scope_id),
            state_policy=StatePolicy(
                enable_state=enable_state,
                state_scopes=state_scopes,
            ),
        )

    @staticmethod
    async def _state_flag(app, event, descriptor, binding):
        """Run ``_build_context_access`` on a fresh builder and return the ``state`` flag."""
        builder = AgentRunContextBuilder(app)
        access = await builder._build_context_access(event, descriptor, binding)
        return access['available_apis']['state']

    @pytest.fixture
    def mock_app(self):
        """Provide a MockApplication instance."""
        return MockApplication()

    @pytest.fixture
    def mock_event(self):
        """Provide an event envelope bound to conversation ``conv_001``."""
        return self._make_event('evt_001', 'conv_001')

    @pytest.fixture
    def mock_descriptor(self):
        """Provide the default runner descriptor."""
        return make_descriptor()

    @pytest.mark.asyncio
    async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
        """Expects ``state`` to be True when state is enabled and scopes are non-empty."""
        binding = self._make_binding('binding_001', True, ['conversation', 'actor'])

        flag = await self._state_flag(mock_app, mock_event, mock_descriptor, binding)

        assert flag is True

    @pytest.mark.asyncio
    async def test_enable_state_false_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """Expects ``state`` to be False when state is disabled."""
        binding = self._make_binding('binding_001', False, [])

        flag = await self._state_flag(mock_app, mock_event, mock_descriptor, binding)

        assert flag is False

    @pytest.mark.asyncio
    async def test_enable_state_true_empty_scopes_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """Expects ``state`` to be False when state is enabled but no scopes are listed."""
        binding = self._make_binding('binding_001', True, [])

        flag = await self._state_flag(mock_app, mock_event, mock_descriptor, binding)

        assert flag is False

    @pytest.mark.asyncio
    async def test_no_binding_sets_state_false(self, mock_app, mock_event, mock_descriptor):
        """Expects ``state`` to be False when ``binding=None`` is passed."""
        builder = AgentRunContextBuilder(mock_app)

        access = await builder._build_context_access(mock_event, mock_descriptor, binding=None)

        assert access['available_apis']['state'] is False

    @pytest.mark.asyncio
    async def test_runner_scope_available_without_conversation(self, mock_app, mock_descriptor):
        """Expects ``state`` to be True for a runner-only scope on an event without a conversation."""
        event_without_conversation = self._make_event('evt_002', None)
        binding = self._make_binding(
            'binding_002',
            True,
            ['runner'],
            scope_type='workspace',
            scope_id='ws_001',
        )

        flag = await self._state_flag(mock_app, event_without_conversation, mock_descriptor, binding)

        assert flag is True

    @pytest.mark.asyncio
    async def test_multiple_scopes_all_available(self, mock_app, mock_event, mock_descriptor):
        """Expects ``state`` to be True when all four scopes are enabled."""
        binding = self._make_binding(
            'binding_003',
            True,
            ['conversation', 'actor', 'subject', 'runner'],
        )

        flag = await self._state_flag(mock_app, mock_event, mock_descriptor, binding)

        assert flag is True
|
||||
|
||||
|
||||
class TestStatePolicyFromBinding:
    """Checks that StatePolicy exposes the values it was constructed with."""

    def test_state_policy_structure(self):
        """Expects an enabled policy to report its flag and all four scopes."""
        scopes = ['conversation', 'actor', 'subject', 'runner']

        enabled = StatePolicy(enable_state=True, state_scopes=scopes)

        assert enabled.enable_state is True
        assert len(enabled.state_scopes) == 4
        assert 'conversation' in enabled.state_scopes

    def test_state_policy_disabled(self):
        """Expects a disabled policy to report a False flag and no scopes."""
        disabled = StatePolicy(enable_state=False, state_scopes=[])

        assert disabled.enable_state is False
        assert len(disabled.state_scopes) == 0
|
||||
|
||||
|
||||
class TestBindingWithStatePolicy:
    """Checks that AgentBinding carries the state policy it was constructed with."""

    def test_binding_contains_state_policy(self):
        """Expects ``binding.state_policy`` to be set and to report ``enable_state`` as True."""
        policy = StatePolicy(enable_state=True, state_scopes=['conversation'])
        scope = BindingScope(scope_type='agent', scope_id='conv_001')

        binding = AgentBinding(
            binding_id='binding_001',
            runner_id='plugin:test/runner/default',
            scope=scope,
            state_policy=policy,
        )

        assert binding.state_policy is not None
        assert binding.state_policy.enable_state is True
|
||||
|
||||
|
||||
class TestContextAccessOtherAPIs:
    """Checks the non-state ``available_apis`` flags returned by ``_build_context_access``."""

    @staticmethod
    async def _available_apis(app, conversation_id, descriptor):
        """Build the context access for a mocked event and return its ``available_apis`` mapping.

        The event is a MagicMock with only ``conversation_id`` and ``thread_id`` set;
        the binding always has state disabled.
        """
        event = MagicMock()
        event.conversation_id = conversation_id
        event.thread_id = None

        binding = AgentBinding(
            binding_id='binding_001',
            runner_id='plugin:test/runner/default',
            scope=BindingScope(scope_type='agent', scope_id='conv_001'),
            state_policy=StatePolicy(enable_state=False, state_scopes=[]),
        )

        builder = AgentRunContextBuilder(app)
        access = await builder._build_context_access(event, descriptor, binding)
        return access['available_apis']

    @pytest.fixture
    def mock_app(self):
        """Provide a MockApplication instance."""
        return MockApplication()

    @pytest.mark.asyncio
    async def test_history_apis_enabled_with_conversation(self, mock_app):
        """Expects history page/search to be on and ``prompt_get`` off when a conversation is set."""
        apis = await self._available_apis(mock_app, 'conv_001', make_descriptor())

        assert apis['prompt_get'] is False
        assert apis['history_page'] is True
        assert apis['history_search'] is True

    @pytest.mark.asyncio
    async def test_event_apis_enabled_by_default(self, mock_app):
        """Expects event get/page to be on with the default descriptor and a conversation."""
        apis = await self._available_apis(mock_app, 'conv_001', make_descriptor())

        assert apis['event_get'] is True
        assert apis['event_page'] is True

    @pytest.mark.asyncio
    async def test_conversation_required_apis_disabled_without_conversation(self, mock_app):
        """Expects history, event paging and state to be off, and ``event_get`` on, without a conversation."""
        apis = await self._available_apis(mock_app, None, make_descriptor())

        assert apis['history_page'] is False
        assert apis['history_search'] is False
        assert apis['event_get'] is True
        assert apis['event_page'] is False
        assert apis['state'] is False

    @pytest.mark.asyncio
    async def test_manifest_permissions_disable_context_apis(self, mock_app):
        """Expects history, event and storage APIs to be off when the descriptor has empty permissions."""
        apis = await self._available_apis(mock_app, 'conv_001', make_descriptor(permissions={}))

        assert apis['history_page'] is False
        assert apis['history_search'] is False
        assert apis['event_get'] is False
        assert apis['event_page'] is False
        assert apis['storage'] is False
|
||||
@@ -0,0 +1,428 @@
|
||||
"""Test that LangBot context builder output validates against SDK AgentRunContext."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# SDK imports for validation
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context import AgentRunContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import AgentEventContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context_access import ContextAccess
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
|
||||
|
||||
# LangBot imports
|
||||
from langbot.pkg.agent.runner.context_builder import (
|
||||
AgentRunContextBuilder,
|
||||
AgentResources as BuilderResources,
|
||||
)
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class TestContextValidation:
    """Check that AgentRunContextBuilder output validates against the SDK AgentRunContext.

    Every test builds a context dict with ``build_context_from_event`` while the
    persistent state store accessor is patched, then inspects the raw dict and/or
    the result of ``AgentRunContext.model_validate``.
    """

    # Dotted name of the state-store accessor as looked up from the context_builder module.
    _STATE_STORE_TARGET = 'langbot.pkg.agent.runner.context_builder.get_persistent_state_store'

    def _make_mock_app(self):
        """Return a MagicMock specced on ``app.Application`` with the managers these tests touch."""
        mock_app = MagicMock(spec=app.Application)
        mock_app.ver_mgr = MagicMock()
        mock_app.ver_mgr.get_current_version = MagicMock(return_value="1.0.0")
        mock_app.persistence_mgr = MagicMock()
        mock_app.persistence_mgr.get_db_engine = MagicMock()
        mock_app.logger = MagicMock()
        return mock_app

    def _make_event_envelope(self) -> AgentEventEnvelope:
        """Return a ``message.received`` envelope with a user actor and the text "Hello world"."""
        from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext

        return AgentEventEnvelope(
            event_id="evt_1",
            event_type="message.received",
            event_time=1700000000,
            source="platform",
            source_event_type="platform.message",
            bot_id="bot_1",
            workspace_id="workspace_1",
            conversation_id="conv_1",
            thread_id=None,
            actor=ActorContext(
                actor_type="user",
                actor_id="user_1",
                actor_name="Test User",
            ),
            subject=None,
            input=AgentInput(text="Hello world"),
            delivery=DeliveryContext(surface="test"),
            data={"platform_event_id": "source_evt_1"},
        )

    def _make_binding(self) -> AgentBinding:
        """Return an enabled agent-scoped binding for ``message.received`` events."""
        return AgentBinding(
            binding_id="binding_1",
            scope=BindingScope(scope_type="agent", scope_id="pipeline_1"),
            event_types=["message.received"],
            runner_id="plugin:test/plugin/runner",
            runner_config={"timeout": 300},
            agent_id="pipeline_1",
            enabled=True,
        )

    def _make_resources(self) -> BuilderResources:
        """Return a resources mapping with empty collections and both storage flags enabled."""
        return {
            'models': [],
            'tools': [],
            'knowledge_bases': [],
            'skills': [],
            'files': [],
            'storage': {'plugin_storage': True, 'workspace_storage': True},
            'platform_capabilities': {},
        }

    def _make_descriptor(self):
        """Return a plugin runner descriptor with history, events and storage permissions."""
        return AgentRunnerDescriptor(
            id="plugin:test/plugin/runner",
            source="plugin",
            label={"en_US": "Test Runner"},
            plugin_author="test",
            plugin_name="plugin",
            runner_name="runner",
            permissions={
                "history": ["page", "search"],
                "events": ["get", "page"],
                "storage": ["plugin", "workspace"],
            },
        )

    async def _build_context(self, builder, *, event=None, binding=None, resources=None):
        """Call ``builder.build_context_from_event`` with an empty state snapshot.

        Args:
            builder: The AgentRunContextBuilder under test.
            event: Envelope to build from; defaults to ``_make_event_envelope()``.
            binding: Binding to build with; defaults to ``_make_binding()``.
            resources: Resources mapping; defaults to ``_make_resources()``.

        Returns:
            Whatever ``build_context_from_event`` returns (the tests treat it as a dict).
        """
        if event is None:
            event = self._make_event_envelope()
        if binding is None:
            binding = self._make_binding()
        if resources is None:
            resources = self._make_resources()
        descriptor = self._make_descriptor()

        # Replace the persistent state store so the builder sees an empty snapshot
        # for every scope instead of reaching a real store.
        with patch(self._STATE_STORE_TARGET) as mock_get_store:
            mock_store = AsyncMock()
            mock_store.build_snapshot_from_event = AsyncMock(return_value={
                'conversation': {},
                'actor': {},
                'subject': {},
                'runner': {},
            })
            mock_get_store.return_value = mock_store

            return await builder.build_context_from_event(
                event=event,
                binding=binding,
                descriptor=descriptor,
                resources=resources,
            )

    @pytest.mark.asyncio
    async def test_build_context_from_event_validates(self):
        """Test that build_context_from_event output validates against SDK AgentRunContext."""
        builder = AgentRunContextBuilder(self._make_mock_app())

        context_dict = await self._build_context(builder)

        # model_validate raises ValidationError if the dict does not match the SDK model.
        validated = AgentRunContext.model_validate(context_dict)

        # Verify required fields
        assert validated.run_id is not None
        assert validated.event is not None
        assert isinstance(validated.event, AgentEventContext)
        assert validated.delivery is not None
        assert isinstance(validated.delivery, DeliveryContext)
        assert validated.context is not None
        assert isinstance(validated.context, ContextAccess)
        assert validated.input is not None
        assert isinstance(validated.input, AgentInput)
        assert validated.resources is not None
        assert isinstance(validated.resources, AgentResources)
        assert validated.runtime is not None
        assert isinstance(validated.runtime, AgentRuntimeContext)
        runtime_dump = validated.runtime.model_dump()
        assert "protocol_version" not in runtime_dump
        assert "sdk_protocol_version" not in runtime_dump
        assert "sdk_protocol_version" not in context_dict["runtime"]

        # Verify event context
        assert validated.event.event_id == "evt_1"
        assert validated.event.event_type == "message.received"
        assert validated.event.source == "platform"
        assert validated.event.source_event_type == "platform.message"
        assert validated.event.data == {"platform_event_id": "source_evt_1"}

        # Verify conversation context uses SDK field names
        assert validated.conversation is not None
        assert validated.conversation.bot_id == "bot_1"
        assert validated.conversation.workspace_id == "workspace_1"

        # Verify delivery context
        assert validated.delivery.surface == "test"

        # Verify input
        assert validated.input.text == "Hello world"

    @pytest.mark.asyncio
    async def test_build_context_from_event_populates_model_context_window(self):
        """Runtime metadata should expose the selected LLM model context window."""
        mock_app = self._make_mock_app()
        mock_app.model_mgr = MagicMock()
        mock_app.model_mgr.get_model_by_uuid = AsyncMock(
            return_value=SimpleNamespace(
                model_entity=SimpleNamespace(context_length=128000),
            )
        )
        builder = AgentRunContextBuilder(mock_app)

        # The rerank model comes first; the assertions below expect the lookup
        # to be made for the LLM entry, not for the first list element.
        resources = self._make_resources()
        resources['models'] = [
            {
                'model_id': 'rerank-model',
                'model_type': 'rerank',
                'provider': 'test-provider',
                'operations': ['rerank'],
            },
            {
                'model_id': 'llm-model',
                'model_type': 'llm',
                'provider': 'test-provider',
                'operations': ['invoke', 'stream'],
            },
        ]

        context_dict = await self._build_context(builder, resources=resources)

        assert context_dict['runtime']['metadata']['model_context_window_tokens'] == 128000
        mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('llm-model')

    @pytest.mark.asyncio
    async def test_model_context_window_uses_primary_llm_only(self):
        """Fallback model windows should not replace missing primary model metadata."""
        mock_app = self._make_mock_app()
        mock_app.model_mgr = MagicMock()
        mock_app.model_mgr.get_model_by_uuid = AsyncMock(
            return_value=SimpleNamespace(
                model_entity=SimpleNamespace(context_length=None),
            )
        )
        builder = AgentRunContextBuilder(mock_app)
        resources = self._make_resources()
        resources['models'] = [
            {
                'model_id': 'primary-model',
                'model_type': 'llm',
                'provider': 'test-provider',
                'operations': ['invoke', 'stream'],
            },
            {
                'model_id': 'fallback-model',
                'model_type': 'llm',
                'provider': 'test-provider',
                'operations': ['invoke', 'stream'],
            },
        ]

        assert await builder._build_model_context_window_tokens(resources) is None
        # Exactly one lookup, for the first LLM entry: the fallback must not be consulted.
        mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('primary-model')

    @pytest.mark.asyncio
    async def test_build_context_preserves_subject_data_for_non_message_events(self):
        """Non-message EBA events keep subject.data instead of relying on message text."""
        from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext, SubjectContext

        builder = AgentRunContextBuilder(self._make_mock_app())
        event = AgentEventEnvelope(
            event_id="evt_recall_1",
            event_type="message.recalled",
            event_time=1700000001,
            source="platform",
            source_event_type="platform.message.recall",
            bot_id="bot_1",
            workspace_id="workspace_1",
            conversation_id="conv_1",
            actor=ActorContext(actor_type="user", actor_id="user_1"),
            subject=SubjectContext(
                subject_type="message",
                subject_id="message_1",
                data={"recalled_message_id": "message_1", "reason": "user_recall"},
            ),
            input=AgentInput(text=None),
            delivery=DeliveryContext(surface="test"),
            data={"source_event_id": "source_recall_1"},
        )
        binding = self._make_binding()
        binding.event_types = ["message.recalled"]

        context_dict = await self._build_context(builder, event=event, binding=binding)

        validated = AgentRunContext.model_validate(context_dict)

        assert validated.event.event_type == "message.recalled"
        assert validated.input.text is None
        assert validated.subject is not None
        assert validated.subject.subject_type == "message"
        assert validated.subject.subject_id == "message_1"
        assert validated.subject.data == {"recalled_message_id": "message_1", "reason": "user_recall"}

    @pytest.mark.asyncio
    async def test_build_context_from_event_has_no_legacy_top_level_fields(self):
        """Test that build_context_from_event does NOT have top-level messages/prompt/params."""
        builder = AgentRunContextBuilder(self._make_mock_app())

        context_dict = await self._build_context(builder)

        # Protocol v1 does NOT have these as core fields
        assert 'messages' not in context_dict, "messages should not be top-level in Protocol v1"
        assert 'prompt' not in context_dict, "prompt should not be top-level in Protocol v1"
        assert 'params' not in context_dict, "params should not be top-level in Protocol v1"

        # Protocol v1 DOES have these
        assert 'delivery' in context_dict, "delivery is REQUIRED in Protocol v1"
        assert 'context' in context_dict, "context (ContextAccess) is REQUIRED in Protocol v1"
        assert 'bootstrap' not in context_dict, "Host must not inline bootstrap/history windows"
        assert 'adapter' in context_dict, "adapter should exist"
        assert 'metadata' in context_dict, "metadata should exist"

    @pytest.mark.asyncio
    async def test_build_context_from_event_event_is_not_none(self):
        """Test that event field is NOT None in Protocol v1."""
        builder = AgentRunContextBuilder(self._make_mock_app())

        context_dict = await self._build_context(builder)

        # event is REQUIRED in Protocol v1
        assert context_dict.get('event') is not None, "event is REQUIRED for Protocol v1"

        # Validate
        validated = AgentRunContext.model_validate(context_dict)
        assert validated.event is not None

    @pytest.mark.asyncio
    async def test_build_context_from_event_delivery_is_not_none(self):
        """Test that delivery field is NOT None in Protocol v1."""
        builder = AgentRunContextBuilder(self._make_mock_app())

        context_dict = await self._build_context(builder)

        # delivery is REQUIRED in Protocol v1
        assert context_dict.get('delivery') is not None, "delivery is REQUIRED for Protocol v1"

        # Validate
        validated = AgentRunContext.model_validate(context_dict)
        assert validated.delivery is not None
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Tests for event-first Protocol v1 entities and Query entry adapter.
|
||||
|
||||
Tests cover:
|
||||
1. Query -> AgentEventEnvelope conversion
|
||||
2. Current config -> AgentConfig projection and single-binding resolution
|
||||
3. AgentRunContext not inlining full history by default
|
||||
4. LangBot Host not defining context-window controls
|
||||
5. Event-first run() entry point
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Import SDK entities
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import (
|
||||
AgentEventContext,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.trigger import AgentTrigger
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context import AgentRunContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.result import (
|
||||
AgentRunResult,
|
||||
)
|
||||
|
||||
# Import LangBot host models
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
from langbot.pkg.agent.runner.binding_resolver import (
|
||||
AgentBindingResolver,
|
||||
AgentBindingResolutionError,
|
||||
)
|
||||
|
||||
|
||||
class TestQueryToEventEnvelope:
    """Exercise QueryEntryAdapter.query_to_event using the ``mock_query`` fixture."""

    @staticmethod
    def _platform_source_event(dumped):
        """Return a Mock platform event whose ``model_dump()`` yields ``dumped``."""
        source_event = Mock()
        source_event.type = "platform.message.created"
        source_event.time = 1700000000
        source_event.sender = None
        source_event.model_dump = Mock(return_value=dumped)
        return source_event

    def test_query_to_event_basic_fields(self, mock_query):
        """Event type, source, bot id and actor type are derived from the Query."""
        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.event_type == "message.received"
        assert envelope.source == "host_adapter"
        assert envelope.bot_id == mock_query.bot_uuid
        assert envelope.actor is not None
        assert envelope.actor.actor_type == "user"

    def test_query_to_event_input(self, mock_query):
        """The user text lands in input.text and no message_chain key is dumped."""
        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.input is not None
        assert envelope.input.text == "Hello world"
        assert "message_chain" not in envelope.input.model_dump()

    def test_query_to_event_conversation(self, mock_query):
        """The session's conversation UUID becomes the event conversation_id."""
        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.conversation_id == "conv-uuid-123"

    def test_query_to_event_prefers_variable_conversation_id_when_conversation_uuid_missing(self, mock_query):
        """With no conversation UUID, the ``conversation_id`` pipeline variable is used."""
        mock_query.session.using_conversation.uuid = None
        mock_query.variables["conversation_id"] = "conv-from-vars"

        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.conversation_id == "conv-from-vars"

    def test_query_to_event_falls_back_to_launcher_session_for_state_scope(self, mock_query):
        """With neither UUID nor variable, the id is built from launcher type and id."""
        mock_query.session.using_conversation.uuid = None

        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.conversation_id == "person_launcher-123"

    def test_query_to_event_delivery_context(self, mock_query):
        """A delivery context with surface "platform" and a boolean streaming flag is attached."""
        delivery = QueryEntryAdapter.query_to_event(mock_query).delivery

        assert delivery is not None
        assert delivery.surface == "platform"
        assert isinstance(delivery.supports_streaming, bool)

    def test_query_to_event_preserves_source_event_data(self, mock_query):
        """Source event type, time and scalar fields survive the adapter boundary."""
        mock_query.message_event = self._platform_source_event({
            "type": "platform.message.created",
            "message_id": "source-message-1",
            "source_platform_object": {"large": "payload"},
        })

        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.source_event_type == "platform.message.created"
        assert envelope.event_time == 1700000000
        # source_platform_object is expected to be dropped from event.data.
        assert envelope.data == {
            "type": "platform.message.created",
            "message_id": "source-message-1",
        }

    def test_query_to_event_keeps_large_payloads_out_of_event_data(self, mock_query):
        """Large or nested platform payloads should not be duplicated into event.data."""
        mock_query.message_event = self._platform_source_event({
            "type": "platform.message.created",
            "message_id": "source-message-1",
            "message_chain": [{"type": "Image", "base64": "data:image/png;base64," + ("a" * 1024)}],
            "raw_text": "x" * 1024,
            "source_platform_object": {"large": "payload"},
        })

        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.data == {
            "type": "platform.message.created",
            "message_id": "source-message-1",
        }

    def test_query_to_event_handles_missing_message_chain(self, mock_query):
        """A Query without a message_chain attribute still yields a reply target."""
        # Deleting the attribute makes later access on the Mock raise AttributeError.
        del mock_query.message_chain

        envelope = QueryEntryAdapter.query_to_event(mock_query)

        assert envelope.delivery.reply_target == {"message_id": None}

    def test_query_to_event_scopes_pipeline_local_event_ids(self, mock_query):
        """Pipeline-local message IDs must not become global audit IDs."""
        before = QueryEntryAdapter.query_to_event(mock_query)

        mock_query.launcher_id = "launcher-456"
        after = QueryEntryAdapter.query_to_event(mock_query)

        assert before.event_id.startswith("host:")
        assert before.event_id != "789"
        assert after.event_id != before.event_id
|
||||
|
||||
|
||||
class TestQueryConfigToAgentConfig:
    """Cover projection of the current config to AgentConfig and single-binding resolution."""

    def test_config_to_agent_config_runner_id(self, mock_query):
        """The runner id passed to the adapter is carried onto the AgentConfig."""
        runner_id = "plugin:author/plugin/runner"

        agent_config = QueryEntryAdapter.config_to_agent_config(mock_query, runner_id)

        assert agent_config.runner_id == "plugin:author/plugin/runner"

    def test_config_to_agent_config_uses_legacy_runner_config_migration(self, mock_query):
        """Temporary query adapter must share the normal runner config resolver."""
        legacy_local_agent = {
            "model": "model-primary",
            "knowledge-base": "kb-001",
        }
        mock_query.pipeline_config = {
            "ai": {
                "runner": {"runner": "local-agent"},
                "local-agent": legacy_local_agent,
            }
        }

        agent_config = QueryEntryAdapter.config_to_agent_config(
            mock_query,
            "plugin:langbot/local-agent/default",
        )

        runner_config = agent_config.runner_config
        # The legacy scalar model is expected in primary/fallbacks form ...
        assert runner_config["model"] == {
            "primary": "model-primary",
            "fallbacks": [],
        }
        # ... and the single knowledge base as a one-element list.
        assert runner_config["knowledge-bases"] == ["kb-001"]

    def test_resolver_projects_agent_scope(self, mock_query):
        """The resolved binding is agent-scoped and keyed by the pipeline UUID."""
        envelope = QueryEntryAdapter.query_to_event(mock_query)
        agent_config = QueryEntryAdapter.config_to_agent_config(
            mock_query, "plugin:test/plugin/runner"
        )

        resolved = AgentBindingResolver().resolve_one(envelope, [agent_config])

        assert resolved.scope.scope_type == "agent"
        assert resolved.scope.scope_id == mock_query.pipeline_uuid
        assert resolved.agent_id == mock_query.pipeline_uuid

    def test_resolver_rejects_multiple_matching_agents(self, mock_query):
        """Event dispatch is single-Agent in v1."""
        envelope = QueryEntryAdapter.query_to_event(mock_query)
        original = QueryEntryAdapter.config_to_agent_config(
            mock_query, "plugin:test/plugin/runner"
        )
        # Same config under a different agent id, so two candidates are offered.
        duplicate = original.model_copy(update={"agent_id": "agent_2"})
        resolver = AgentBindingResolver()

        with pytest.raises(AgentBindingResolutionError):
            resolver.resolve_one(envelope, [original, duplicate])
|
||||
|
||||
class TestAgentRunContextProtocolV1:
    """Test AgentRunContext Protocol v1 behavior."""

    @staticmethod
    def _make_context() -> AgentRunContext:
        """Build a minimal AgentRunContext for a ``message.received`` event.

        Returns:
            An AgentRunContext with run id "run_1", text input "Hello", a
            "platform" delivery surface and default resources/runtime.
        """
        from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
        from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
        from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext

        # Named agent_input rather than input so the builtin is not shadowed.
        agent_input = AgentInput(text="Hello")
        return AgentRunContext(
            run_id="run_1",
            trigger=AgentTrigger(type="message.received"),
            event=AgentEventContext(
                event_id="evt_1",
                event_type="message.received",
                source="platform",
            ),
            input=agent_input,
            delivery=DeliveryContext(surface="platform"),
            resources=AgentResources(),
            runtime=AgentRuntimeContext(),
        )

    def test_sdk_context_event_required(self):
        """Test that event is required in Protocol v1 context."""
        ctx = self._make_context()

        assert ctx.event is not None
        assert ctx.event.event_type == "message.received"

    def test_sdk_context_has_no_history_message_fields(self):
        """AgentRunContext should not expose inline history message fields."""
        ctx = self._make_context()

        assert "messages" not in AgentRunContext.model_fields
        assert "bootstrap" not in AgentRunContext.model_fields
        assert not hasattr(ctx, "bootstrap")
|
||||
|
||||
|
||||
class TestHostManagedHistoryNotInProtocol:
    """Guard against Host-managed history payloads appearing in Protocol v1."""

    def test_messages_not_in_sdk_context_top_level(self):
        """The SDK context model declares no top-level ``messages`` field."""
        declared_fields = AgentRunContext.model_fields

        assert "messages" not in declared_fields
|
||||
|
||||
|
||||
class TestSDKResultProtocolV1:
    """Cover the SDK AgentRunResult as used by Protocol v1."""

    def test_result_requires_run_id(self):
        """A message_completed result carries the run_id it was created with."""
        from langbot_plugin.api.entities.builtin.provider.message import Message

        reply = Message(role="assistant", content="Hello")

        result = AgentRunResult.message_completed(run_id="run_1", message=reply)

        assert result.run_id == "run_1"
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
def mock_query():
    """Provide a Mock Query with a person session, one text part and a conversation UUID."""
    # Single text content element; spec limits the mock to these three attributes.
    text_part = Mock(spec=['type', 'text', 'model_dump'])
    text_part.type = 'text'
    text_part.text = 'Hello world'
    text_part.model_dump = Mock(return_value={'type': 'text', 'text': 'Hello world'})

    user_message = Mock()
    user_message.content = [text_part]

    # Platform message chain carrying the pipeline-local message id.
    chain = Mock()
    chain.message_id = 789
    chain.model_dump = Mock(return_value={'message_id': 789, 'components': []})

    # Session bound to an existing conversation.
    session = Mock()
    session.launcher_type = Mock(value="person")
    session.launcher_id = "launcher-123"
    session.using_conversation = Mock()
    session.using_conversation.uuid = "conv-uuid-123"

    query = Mock()
    query.query_id = 123
    query.bot_uuid = "bot-uuid-123"
    query.pipeline_uuid = "pipeline-uuid-456"
    query.launcher_type = Mock(value="person")
    query.launcher_id = "launcher-123"
    query.sender_id = "sender-123"
    query.pipeline_config = {
        "ai": {
            "runner": "plugin:test/plugin/runner",
        }
    }
    query.variables = {}
    query.user_message = user_message
    query.message_chain = chain
    query.message_event = None
    query.session = session

    # No tools and no explicit model selected by default.
    query.use_funcs = []
    query.use_llm_model_uuid = None

    return query
|
||||
|
||||
|
||||
@pytest.fixture
def mock_query_no_session():
    """Provide a Mock Query whose ``session`` is None and whose message id is -1."""
    # Single text content element; spec limits the mock to these three attributes.
    text_part = Mock(spec=['type', 'text', 'model_dump'])
    text_part.type = 'text'
    text_part.text = 'Test message'
    text_part.model_dump = Mock(return_value={'type': 'text', 'text': 'Test message'})

    user_message = Mock()
    user_message.content = [text_part]

    chain = Mock()
    chain.message_id = -1
    chain.model_dump = Mock(return_value={'message_id': -1, 'components': []})

    query = Mock()
    query.query_id = 456
    query.bot_uuid = "bot-uuid-456"
    query.pipeline_uuid = "pipeline-uuid-789"
    query.launcher_type = Mock(value="person")
    query.launcher_id = "launcher-456"
    query.sender_id = "sender-456"
    query.pipeline_config = {
        "ai": {
            "runner": "plugin:test/plugin/runner",
        }
    }
    query.variables = {}
    query.user_message = user_message
    query.message_chain = chain
    query.message_event = None
    query.session = None

    # No tools and no explicit model selected by default.
    query.use_funcs = []
    query.use_llm_model_uuid = None

    return query
|
||||
@@ -0,0 +1,795 @@
|
||||
"""Tests for EventLog, Transcript, and history/event APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.host_models import (
|
||||
AgentEventEnvelope,
|
||||
AgentBinding,
|
||||
BindingScope,
|
||||
ResourcePolicy,
|
||||
StatePolicy,
|
||||
DeliveryPolicy,
|
||||
)
|
||||
from langbot.pkg.agent.runner.event_log_store import EventLogStore
|
||||
from langbot.pkg.agent.runner.transcript_store import TranscriptStore
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import (
|
||||
ActorContext,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
|
||||
|
||||
def make_event_envelope(
    event_id: str = "evt_1",
    event_type: str = "message.received",
    conversation_id: str | None = "conv_1",
    actor_id: str | None = "user_1",
    input_text: str = "Hello",
) -> AgentEventEnvelope:
    """Build an AgentEventEnvelope for bot "bot_1" on the "test" delivery surface.

    Args:
        event_id: Identifier stored on the envelope.
        event_type: Event type stored on the envelope.
        conversation_id: Conversation the event belongs to, or None.
        actor_id: User id for the actor; a falsy value produces no actor at all.
        input_text: Text placed in the envelope's AgentInput.

    Returns:
        The populated envelope, with workspace, thread and subject left as None.
    """
    actor = None
    if actor_id:
        actor = ActorContext(
            actor_type="user",
            actor_id=actor_id,
            actor_name="Test User",
        )

    return AgentEventEnvelope(
        event_id=event_id,
        event_type=event_type,
        event_time=1700000000,
        source="platform",
        bot_id="bot_1",
        workspace_id=None,
        conversation_id=conversation_id,
        thread_id=None,
        actor=actor,
        subject=None,
        input=AgentInput(text=input_text),
        delivery=DeliveryContext(surface="test"),
    )
|
||||
|
||||
|
||||
def make_binding(runner_id: str = "plugin:test/plugin/runner") -> AgentBinding:
    """Build an enabled, agent-scoped binding for ``message.received`` events.

    Args:
        runner_id: Runner the binding points at.

    Returns:
        An AgentBinding with empty runner config and default-constructed
        resource, state and delivery policies.
    """
    agent_scope = BindingScope(scope_type="agent", scope_id="pipeline_1")
    return AgentBinding(
        binding_id="binding_1",
        scope=agent_scope,
        event_types=["message.received"],
        runner_id=runner_id,
        runner_config={},
        resource_policy=ResourcePolicy(),
        state_policy=StatePolicy(),
        delivery_policy=DeliveryPolicy(),
        enabled=True,
    )
|
||||
|
||||
|
||||
class TestEventLogStore:
|
||||
"""Test EventLogStore operations."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_append_event(self, mock_db_engine):
    """append_event returns the event id and stores no metadata when none is given."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = EventLogStore(mock_db_engine)

    # add() is called synchronously by the test's expectations, commit() is awaited.
    db_session = AsyncMock()
    db_session.add = MagicMock()
    db_session.commit = AsyncMock()

    with patch.object(store, '_session_factory') as factory:
        # Entering the factory's async context manager hands back the mock session.
        factory.return_value.__aenter__.return_value = db_session

        returned_id = await store.append_event(
            event_id="evt_1",
            event_type="message.received",
            source="platform",
            bot_id="bot_1",
            conversation_id="conv_1",
            actor_type="user",
            actor_id="user_1",
            input_summary="Hello world",
            run_id="run_1",
            runner_id="plugin:test/plugin/runner",
        )

        assert returned_id == "evt_1"
        # First positional argument of add() is the row object handed to the session.
        persisted = db_session.add.call_args.args[0]
        assert persisted.metadata_json is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event_stores_metadata_json(self, mock_db_engine):
|
||||
"""EventLog metadata records steering dispatch/audit facts."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_steering",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
run_id="run_1",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
metadata={
|
||||
"steering": {
|
||||
"status": "queued",
|
||||
"claimed_by_run_id": "run_1",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert event_id == "evt_steering"
|
||||
stored_event = mock_session.add.call_args.args[0]
|
||||
assert '"status": "queued"' in stored_event.metadata_json
|
||||
assert '"claimed_by_run_id": "run_1"' in stored_event.metadata_json
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event_truncates_input_summary(self, mock_db_engine):
|
||||
"""Test that long input summaries are truncated."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
long_text = "x" * 2000
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_2",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
input_summary=long_text,
|
||||
)
|
||||
|
||||
assert event_id == "evt_2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_events_with_conversation_filter(self, mock_db_engine):
|
||||
"""Test paging events with conversation_id filter."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id="conv_1",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert isinstance(items, list)
|
||||
|
||||
|
||||
class TestTranscriptStore:
    """Exercise TranscriptStore against a patched ``_session_factory``."""

    @staticmethod
    def _writable_session():
        """Return an AsyncMock session whose ``add`` is a plain synchronous MagicMock."""
        from unittest.mock import AsyncMock, MagicMock

        fake_session = AsyncMock()
        fake_session.add = MagicMock()
        fake_session.commit = AsyncMock()
        return fake_session

    @staticmethod
    def _session_without_rows():
        """Return an AsyncMock session whose ``execute`` yields a result with no rows."""
        from unittest.mock import AsyncMock, MagicMock

        empty_result = MagicMock()
        empty_result.scalars.return_value.all.return_value = []
        fake_session = AsyncMock()
        fake_session.execute = AsyncMock(return_value=empty_result)
        return fake_session

    @pytest.mark.asyncio
    async def test_append_transcript(self, mock_db_engine):
        """append_transcript returns a non-None id when none is supplied."""
        from unittest.mock import patch

        transcripts = TranscriptStore(mock_db_engine)
        fake_session = self._writable_session()

        # _get_next_seq is stubbed so no sequence lookup reaches the mocked engine.
        with patch.object(transcripts, '_get_next_seq', return_value=1), \
                patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            new_id = await transcripts.append_transcript(
                transcript_id=None,  # Auto-generate
                event_id="evt_1",
                conversation_id="conv_1",
                role="user",
                content="Hello",
            )

        assert new_id is not None

    @pytest.mark.asyncio
    async def test_append_transcript_with_attachments(self, mock_db_engine):
        """append_transcript accepts attachment_refs and returns a non-None id."""
        from unittest.mock import patch

        transcripts = TranscriptStore(mock_db_engine)
        fake_session = self._writable_session()
        image_refs = [
            {"id": "att_1", "type": "image", "url": "http://example.com/img.png"}
        ]

        with patch.object(transcripts, '_get_next_seq', return_value=1), \
                patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            new_id = await transcripts.append_transcript(
                transcript_id=None,  # Auto-generate
                event_id="evt_2",
                conversation_id="conv_1",
                role="assistant",
                content="Here's an image",
                attachment_refs=image_refs,
            )

        assert new_id is not None

    @pytest.mark.asyncio
    async def test_page_transcript_backward(self, mock_db_engine):
        """page_transcript(direction="backward") returns a 4-tuple led by a list."""
        from unittest.mock import patch

        transcripts = TranscriptStore(mock_db_engine)
        fake_session = self._session_without_rows()

        with patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            items, _next_seq, _prev_seq, _has_more = await transcripts.page_transcript(
                conversation_id="conv_1",
                limit=10,
                direction="backward",
            )

        assert isinstance(items, list)

    @pytest.mark.asyncio
    async def test_page_transcript_has_hard_limit(self, mock_db_engine):
        """Requesting more than HARD_LIMIT items never yields more than HARD_LIMIT.

        NOTE(review): the mocked query returns no rows, so the bound holds
        trivially in this test.
        """
        from unittest.mock import patch

        transcripts = TranscriptStore(mock_db_engine)
        fake_session = self._session_without_rows()

        with patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            # Ask for more than the hard limit.
            items, _next_seq, _prev_seq, _has_more = await transcripts.page_transcript(
                conversation_id="conv_1",
                limit=200,  # Request 200, but hard limit is 100
            )

        # The result size is bounded by the store's limit.
        assert len(items) <= transcripts.HARD_LIMIT

    @pytest.mark.asyncio
    async def test_search_transcript(self, mock_db_engine):
        """search_transcript returns a list for a text query."""
        from unittest.mock import patch

        transcripts = TranscriptStore(mock_db_engine)
        fake_session = self._session_without_rows()

        with patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            matches = await transcripts.search_transcript(
                conversation_id="conv_1",
                query_text="database",
                top_k=10,
            )

        assert isinstance(matches, list)
|
||||
|
||||
|
||||
class TestHistoryPageAuthorization:
    """Test history.page authorization."""

    @pytest.mark.asyncio
    async def test_history_page_requires_run_id(self, mock_handler, mock_db_engine):
        """history.page called without a run_id yields an error response.

        NOTE(review): this goes through the ``mock_handler`` fixture, whose
        ``call_action`` is a stub defined in this file.
        """
        from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction

        # Call the stubbed action with a missing run_id.
        result = await mock_handler.call_action(
            PluginToRuntimeAction.HISTORY_PAGE,
            {"run_id": None},
        )

        # Should return error
        assert result.get("ok") is False or "error" in str(result).lower()

    @pytest.mark.asyncio
    async def test_history_page_validates_conversation_scope(self, mock_db_engine):
        """A registered run session records the conversation it is scoped to."""
        session_registry = get_session_registry()

        await session_registry.register(
            run_id="run_1",
            runner_id="plugin:test/plugin/runner",
            query_id=None,
            plugin_identity="test/plugin",
            resources={"models": [], "tools": [], "knowledge_bases": [], "storage": {"plugin_storage": True}},
            conversation_id="conv_1",
        )

        try:
            session = await session_registry.get("run_1")
            assert session is not None
            assert session["authorization"]["conversation_id"] == "conv_1"
        finally:
            # The registry comes from get_session_registry(), so always
            # unregister, even on assertion failure, to avoid leaking "run_1"
            # into other tests.
            await session_registry.unregister("run_1")
|
||||
|
||||
|
||||
class TestEventGetAuthorization:
    """Authorization checks for the event.get action."""

    @pytest.mark.asyncio
    async def test_event_get_requires_run_id(self, mock_handler):
        """event.get called without a run_id yields an error response.

        NOTE(review): this goes through the ``mock_handler`` fixture, whose
        ``call_action`` is a stub defined in this file.
        """
        from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction

        request = {"run_id": None, "event_id": "evt_1"}
        response = await mock_handler.call_action(PluginToRuntimeAction.EVENT_GET, request)

        # Either an explicit ok=False or an error mention counts as rejection.
        rejected = response.get("ok") is False or "error" in str(response).lower()
        assert rejected
|
||||
|
||||
|
||||
class TestContextAccessPopulation:
    """Check the TranscriptStore lookups (latest cursor, has-history) with mocked sessions."""

    @staticmethod
    def _session_for(result):
        """Return an AsyncMock session whose ``execute`` resolves to ``result``."""
        from unittest.mock import AsyncMock

        fake_session = AsyncMock()
        fake_session.execute = AsyncMock(return_value=result)
        return fake_session

    @pytest.mark.asyncio
    async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
        """get_latest_cursor returns either None or a string cursor."""
        from unittest.mock import MagicMock, patch

        transcripts = TranscriptStore(mock_db_engine)

        # Query result with no first row.
        no_row = MagicMock()
        no_row.scalars.return_value.first.return_value = None
        fake_session = self._session_for(no_row)

        with patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            cursor = await transcripts.get_latest_cursor("conv_1")

        # Should return None or a cursor string
        assert cursor is None or isinstance(cursor, str)

    @pytest.mark.asyncio
    async def test_context_access_shows_has_history_before(self, mock_db_engine):
        """has_history_before returns a bool."""
        from unittest.mock import MagicMock, patch

        transcripts = TranscriptStore(mock_db_engine)

        # Scalar query result of zero.
        zero_count = MagicMock()
        zero_count.scalar.return_value = 0
        fake_session = self._session_for(zero_count)

        with patch.object(transcripts, '_session_factory') as factory:
            factory.return_value.__aenter__.return_value = fake_session
            has_history = await transcripts.has_history_before("conv_1", 10)

        assert isinstance(has_history, bool)
|
||||
|
||||
|
||||
class TestEventLogStoreRealSQLite:
    """Test EventLogStore with real SQLite database."""

    @pytest.fixture
    async def db_engine(self):
        """Create an in-memory SQLite database for testing and dispose it afterwards."""
        from sqlalchemy.ext.asyncio import create_async_engine
        from langbot.pkg.entity.persistence.base import Base

        engine = create_async_engine("sqlite+aiosqlite:///:memory:")

        # Create tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        await engine.dispose()

    @pytest.mark.asyncio
    async def test_append_get_event_round_trip(self, db_engine):
        """Test append_event -> get_event round trip with real DB."""
        store = EventLogStore(db_engine)

        # Append event
        event_id = await store.append_event(
            event_id="evt_real_001",
            event_type="message.received",
            source="platform",
            bot_id="bot_001",
            conversation_id="conv_001",
            actor_type="user",
            actor_id="user_001",
            actor_name="Test User",
            input_summary="Hello world",
            run_id="run_001",
            runner_id="plugin:test/plugin/runner",
        )

        assert event_id == "evt_real_001"

        # Get event
        event = await store.get_event(event_id)
        assert event is not None
        expected_fields = {
            "event_id": "evt_real_001",
            "event_type": "message.received",
            "source": "platform",
            "conversation_id": "conv_001",
            "actor_type": "user",
            "actor_id": "user_001",
        }
        for field_name, expected in expected_fields.items():
            assert event[field_name] == expected

    @pytest.mark.asyncio
    async def test_page_events(self, db_engine):
        """Test page_events with real DB: 5 stored events, page size 3."""
        store = EventLogStore(db_engine)

        # Append multiple events
        for i in range(5):
            await store.append_event(
                event_id=f"evt_real_{i:03d}",
                event_type="message.received",
                source="platform",
                conversation_id="conv_001",
                input_summary=f"Message {i}",
            )

        # Page events
        items, next_seq, has_more = await store.page_events(
            conversation_id="conv_001",
            limit=3,
        )

        assert len(items) == 3
        assert has_more is True

    @pytest.mark.asyncio
    async def test_get_latest_cursor(self, db_engine):
        """Test get_latest_cursor with real DB."""
        store = EventLogStore(db_engine)

        # Append events
        for i in range(3):
            await store.append_event(
                event_id=f"evt_cursor_{i:03d}",
                event_type="message.received",
                source="platform",
                conversation_id="conv_cursor",
            )

        # Get latest cursor
        cursor = await store.get_latest_cursor("conv_cursor")
        assert cursor is not None
        assert int(cursor) > 0

    @pytest.mark.asyncio
    async def test_cleanup_events_older_than(self, db_engine):
        """EventLog cleanup removes only rows older than the cutoff."""
        import sqlalchemy
        from langbot.pkg.entity.persistence.event_log import EventLog

        store = EventLogStore(db_engine)
        # Naive UTC "now"; equivalent to the deprecated datetime.utcnow().
        cutoff = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        await store.append_event(
            event_id="evt_cleanup_old",
            event_type="message.received",
            source="platform",
            conversation_id="conv_cleanup",
        )
        await store.append_event(
            event_id="evt_cleanup_new",
            event_type="message.received",
            source="platform",
            conversation_id="conv_cleanup",
        )

        # Move one row two days before the cutoff and the other two days after it.
        async with store._session_factory() as session:
            await session.execute(
                sqlalchemy.update(EventLog)
                .where(EventLog.event_id == "evt_cleanup_old")
                .values(created_at=cutoff - datetime.timedelta(days=2))
            )
            await session.execute(
                sqlalchemy.update(EventLog)
                .where(EventLog.event_id == "evt_cleanup_new")
                .values(created_at=cutoff + datetime.timedelta(days=2))
            )
            await session.commit()

        removed = await store.cleanup_events_older_than(cutoff)

        assert removed == 1
        assert await store.get_event("evt_cleanup_old") is None
        assert await store.get_event("evt_cleanup_new") is not None
|
||||
|
||||
|
||||
class TestTranscriptStoreRealSQLite:
    """Test TranscriptStore with real SQLite database."""

    @pytest.fixture
    async def db_engine(self):
        """Create an in-memory SQLite database for testing and dispose it afterwards."""
        from sqlalchemy.ext.asyncio import create_async_engine
        from langbot.pkg.entity.persistence.base import Base

        engine = create_async_engine("sqlite+aiosqlite:///:memory:")

        # Create tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        await engine.dispose()

    @pytest.mark.asyncio
    async def test_append_page_transcript_round_trip(self, db_engine):
        """Test append_transcript -> page_transcript round trip with real DB."""
        store = TranscriptStore(db_engine)

        # Append transcript items, alternating user/assistant roles.
        for i in range(3):
            await store.append_transcript(
                transcript_id=f"trans_real_{i:03d}",
                event_id=f"evt_{i:03d}",
                conversation_id="conv_001",
                role="user" if i % 2 == 0 else "assistant",
                content=f"Message {i}",
            )

        # Page transcript
        items, next_seq, prev_seq, has_more = await store.page_transcript(
            conversation_id="conv_001",
            limit=10,
        )

        assert len(items) == 3
        assert items[0]["conversation_id"] == "conv_001"

    @pytest.mark.asyncio
    async def test_get_legacy_provider_messages_projects_transcript_history(self, db_engine):
        """Transcript is the canonical source; legacy Pipeline readers get a Message view."""
        store = TranscriptStore(db_engine)

        await store.append_transcript(
            transcript_id="trans_view_001",
            event_id="evt_view_001",
            conversation_id="conv_view",
            role="user",
            content="User text",
            content_json={
                "role": "user",
                "content": [{"type": "text", "text": "User structured text"}],
            },
        )
        # A tool_result row; the assertions below expect it to be absent from the projection.
        await store.append_transcript(
            transcript_id="trans_view_002",
            event_id="evt_view_002",
            conversation_id="conv_view",
            role="tool",
            item_type="tool_result",
            content="ignored tool result",
        )
        await store.append_transcript(
            transcript_id="trans_view_003",
            event_id="evt_view_003",
            conversation_id="conv_view",
            role="assistant",
            content="Assistant text",
        )

        messages = await store.get_legacy_provider_messages("conv_view")

        assert [message.role for message in messages] == ["user", "assistant"]
        assert messages[0].content[0].text == "User structured text"
        assert messages[1].content == "Assistant text"

    @pytest.mark.asyncio
    async def test_get_legacy_provider_messages_filters_scope(self, db_engine):
        """Legacy Pipeline history projection must stay inside the current run scope."""
        store = TranscriptStore(db_engine)

        # (transcript_id, event_id, bot_id, thread_id, role, content); only
        # the first row matches the bot/thread queried below.
        rows = [
            ("trans_scope_001", "evt_scope_001", "bot_001", "thread_001", "user", "Current scope text"),
            ("trans_scope_002", "evt_scope_002", "bot_002", "thread_001", "assistant", "Other bot text"),
            ("trans_scope_003", "evt_scope_003", "bot_001", "thread_002", "assistant", "Other thread text"),
        ]
        for transcript_id, event_id, bot_id, thread_id, role, content in rows:
            await store.append_transcript(
                transcript_id=transcript_id,
                event_id=event_id,
                conversation_id="conv_scope",
                bot_id=bot_id,
                workspace_id="workspace_001",
                thread_id=thread_id,
                role=role,
                content=content,
            )

        messages = await store.get_legacy_provider_messages(
            "conv_scope",
            bot_id="bot_001",
            workspace_id="workspace_001",
            thread_id="thread_001",
            strict_thread=True,
        )

        assert [message.content for message in messages] == ["Current scope text"]

    @pytest.mark.asyncio
    async def test_search_transcript_real_db(self, db_engine):
        """Test search_transcript with real DB."""
        store = TranscriptStore(db_engine)

        # Append transcript items
        await store.append_transcript(
            transcript_id="trans_search_001",
            event_id="evt_search_001",
            conversation_id="conv_search",
            role="user",
            content="I want to learn about databases",
        )
        await store.append_transcript(
            transcript_id="trans_search_002",
            event_id="evt_search_002",
            conversation_id="conv_search",
            role="assistant",
            content="Here is information about databases",
        )

        # Search for "database"
        items = await store.search_transcript(
            conversation_id="conv_search",
            query_text="database",
        )

        # Should find at least one match
        assert len(items) >= 1

    @pytest.mark.asyncio
    async def test_get_latest_cursor_real_db(self, db_engine):
        """Test get_latest_cursor with real DB."""
        store = TranscriptStore(db_engine)

        # Append transcript items
        for i in range(3):
            await store.append_transcript(
                transcript_id=f"trans_cursor_{i:03d}",
                event_id=f"evt_cursor_{i:03d}",
                conversation_id="conv_cursor",
                role="user",
                content=f"Message {i}",
            )

        # Get latest cursor
        cursor = await store.get_latest_cursor("conv_cursor")
        assert cursor is not None
        assert int(cursor) > 0

    @pytest.mark.asyncio
    async def test_cleanup_transcripts_older_than(self, db_engine):
        """Transcript cleanup removes only rows older than the cutoff."""
        import sqlalchemy
        from langbot.pkg.entity.persistence.transcript import Transcript

        store = TranscriptStore(db_engine)
        # Naive UTC "now"; equivalent to the deprecated datetime.utcnow().
        cutoff = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        await store.append_transcript(
            transcript_id="trans_cleanup_old",
            event_id="evt_cleanup_old",
            conversation_id="conv_cleanup",
            role="user",
            content="old",
        )
        await store.append_transcript(
            transcript_id="trans_cleanup_new",
            event_id="evt_cleanup_new",
            conversation_id="conv_cleanup",
            role="assistant",
            content="new",
        )

        # Move one row two days before the cutoff and the other two days after it.
        async with store._session_factory() as session:
            await session.execute(
                sqlalchemy.update(Transcript)
                .where(Transcript.transcript_id == "trans_cleanup_old")
                .values(created_at=cutoff - datetime.timedelta(days=2))
            )
            await session.execute(
                sqlalchemy.update(Transcript)
                .where(Transcript.transcript_id == "trans_cleanup_new")
                .values(created_at=cutoff + datetime.timedelta(days=2))
            )
            await session.commit()

        removed = await store.cleanup_transcripts_older_than(cutoff)
        items, _, _, _ = await store.page_transcript("conv_cleanup", limit=10)

        assert removed == 1
        assert [item["content"] for item in items] == ["new"]
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
def mock_db_engine():
    """Provide a MagicMock constrained to the ``AsyncEngine`` interface."""
    from sqlalchemy.ext.asyncio import AsyncEngine
    from unittest.mock import MagicMock

    return MagicMock(spec=AsyncEngine)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_handler():
    """Provide a ``Handler`` subclass whose ``call_action`` is a canned stub."""
    from langbot_plugin.runtime.io.handler import Handler

    class MockHandler(Handler):
        def __init__(self):
            # Handler.__init__ is not invoked; only this attribute is set.
            self._responses = {}

        async def call_action(self, action, data, timeout=30):
            # Succeed only when a truthy run_id is supplied.
            if data.get("run_id"):
                return {"ok": True, "data": {}}
            return {"ok": False, "message": "run_id is required"}

    stub = MockHandler()
    return stub
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,323 @@
|
||||
"""Tests for AgentRunner history/event pull API authorization."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.event_log_store import EventLogStore
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.entity.persistence import event_log as event_log_model
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.page_results import (
|
||||
AgentEventRecord,
|
||||
EventPage,
|
||||
)
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
    """Empty placeholder passed where a connection object is required."""
|
||||
|
||||
|
||||
class FakeApplication:
    """Minimal application stand-in exposing ``logger`` and ``persistence_mgr``."""

    def __init__(self, db_engine):
        # persistence_mgr.get_db_engine() hands back the engine given here.
        persistence = MagicMock()
        persistence.get_db_engine = MagicMock(return_value=db_engine)

        self.logger = MagicMock()
        self.persistence_mgr = persistence
|
||||
|
||||
|
||||
@pytest.fixture
def session_registry(monkeypatch):
    """Install a fresh ``AgentRunSessionRegistry`` as the module-level registry for one test."""
    fresh_registry = AgentRunSessionRegistry()
    patch_target = 'langbot.pkg.agent.runner.session_registry._global_registry'
    # monkeypatch restores the previous value when the test finishes.
    monkeypatch.setattr(patch_target, fresh_registry)
    return fresh_registry
|
||||
|
||||
|
||||
@pytest.fixture
async def db_engine():
    """Yield an in-memory aiosqlite engine with all ``Base`` tables created, then dispose it."""
    sqlite_engine = create_async_engine('sqlite+aiosqlite:///:memory:')

    # Sanity-check the EventLog table name before creating the schema.
    # NOTE(review): presumably this also keeps the event_log model import in use
    # so its table is registered on Base — confirm.
    assert event_log_model.EventLog.__tablename__ == 'event_log'

    async with sqlite_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield sqlite_engine

    await sqlite_engine.dispose()
|
||||
|
||||
|
||||
def _handler(db_engine, session_registry):
    """Create a ``RuntimeConnectionHandler`` wired to a fake connection and application.

    ``session_registry`` is accepted but not referenced in this function.
    """

    async def fake_disconnect():
        return True

    application = FakeApplication(db_engine)
    connection = FakeConnection()
    return RuntimeConnectionHandler(connection, fake_disconnect, application)
|
||||
|
||||
|
||||
async def _register_session(
    session_registry,
    *,
    run_id='run_1',
    conversation_id='conv_1',
    bot_id=None,
    workspace_id=None,
    thread_id=None,
    available_apis=None,
):
    """Register a run session for the ``test/runner`` plugin on ``session_registry``.

    A falsy ``available_apis`` is registered as an empty dict.
    """
    registration = {
        'run_id': run_id,
        'runner_id': 'plugin:test/runner/default',
        'query_id': None,
        'plugin_identity': 'test/runner',
        'resources': make_resources(),
        'conversation_id': conversation_id,
        'bot_id': bot_id,
        'workspace_id': workspace_id,
        'thread_id': thread_id,
        'available_apis': available_apis or {},
    }
    await session_registry.register(**registration)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_history_page_requires_runtime_capability(session_registry, db_engine):
    """history.page is rejected as not authorized when the session has ``history_page`` disabled."""
    await _register_session(session_registry, available_apis={'history_page': False})
    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]

    request = {
        'run_id': 'run_1',
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code != 0
    assert 'not authorized' in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_history_page_rejects_cross_conversation(session_registry, db_engine):
    """history.page is rejected when it names a conversation other than the session's own."""
    await _register_session(session_registry, available_apis={'history_page': True})
    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]

    request = {
        'run_id': 'run_1',
        'conversation_id': 'conv_other',  # session was registered for conv_1
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code != 0
    assert 'not accessible' in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine):
    """history.search is rejected when ``filters`` names a different conversation."""
    await _register_session(session_registry, available_apis={'history_search': True})
    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value]

    request = {
        'run_id': 'run_1',
        'query': 'hello',
        'filters': {'conversation_id': 'conv_other'},  # session was registered for conv_1
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code != 0
    assert 'not accessible' in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_page_requires_runtime_capability(session_registry, db_engine):
    """event.page is rejected as not authorized when the session has ``event_page`` disabled."""
    await _register_session(session_registry, available_apis={'event_page': False})
    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]

    request = {
        'run_id': 'run_1',
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code != 0
    assert 'not authorized' in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_page_rejects_cross_conversation(session_registry, db_engine):
    """event.page is rejected when it names a conversation other than the session's own."""
    await _register_session(session_registry, available_apis={'event_page': True})
    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]

    request = {
        'run_id': 'run_1',
        'conversation_id': 'conv_other',  # session was registered for conv_1
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code != 0
    assert 'not accessible' in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_get_returns_sdk_record_projection(session_registry, db_engine):
    """event.get returns data that validates as ``AgentEventRecord`` and omits internal columns."""
    await _register_session(session_registry, available_apis={'event_get': True})

    # Store one event, including fields the response must not contain.
    log_store = EventLogStore(db_engine)
    stored_id = await log_store.append_event(
        event_id='evt_projection_1',
        event_type='message.received',
        source='platform',
        conversation_id='conv_1',
        actor_type='user',
        actor_id='user_1',
        input_summary='hello',
        input_json={'internal': 'not part of AgentEventRecord'},
        run_id='run_1',
        runner_id='plugin:test/runner/default',
    )

    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.EVENT_GET.value]
    request = {
        'run_id': 'run_1',
        'event_id': stored_id,
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code == 0
    AgentEventRecord.model_validate(response.data)
    for internal_key in ('id', 'input_json', 'run_id', 'runner_id'):
        assert internal_key not in response.data
    assert response.data['seq'] is not None
    # The cursor is the string form of the sequence number.
    assert response.data['cursor'] == str(response.data['seq'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_page_returns_sdk_page_projection(session_registry, db_engine):
    """event.page returns data that validates as ``EventPage`` and whose items omit internal columns."""
    await _register_session(session_registry, available_apis={'event_page': True})

    # The event is stored under run_other, while the request below uses run_1.
    log_store = EventLogStore(db_engine)
    await log_store.append_event(
        event_id='evt_projection_page_1',
        event_type='message.received',
        source='platform',
        conversation_id='conv_1',
        input_json={'internal': 'not part of AgentEventRecord'},
        run_id='run_other',
        runner_id='plugin:test/runner/default',
    )

    runtime_handler = _handler(db_engine, session_registry)
    action = runtime_handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
    request = {
        'run_id': 'run_1',
        'caller_plugin_identity': 'test/runner',
    }
    response = await action(request)

    assert response.code == 0
    page = EventPage.model_validate(response.data)
    assert len(page.items) == 1

    first_item = response.data['items'][0]
    for internal_key in ('id', 'input_json', 'run_id', 'runner_id'):
        assert internal_key not in first_item
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_history_page_filters_run_scope_thread_and_bot(session_registry, db_engine):
    """HISTORY_PAGE returns only the transcript matching the session's bot and thread."""
    from langbot.pkg.agent.runner.transcript_store import TranscriptStore

    await _register_session(
        session_registry,
        bot_id='bot_1',
        thread_id='thread_1',
        available_apis={'history_page': True},
    )
    transcript_store = TranscriptStore(db_engine)
    # (transcript_id, event_id, bot_id, thread_id, content) — only the first row
    # matches both the bot and the thread the session was registered with.
    seeded_rows = (
        ('tr_visible', 'evt_visible', 'bot_1', 'thread_1', 'visible'),
        ('tr_other_bot', 'evt_other_bot', 'bot_2', 'thread_1', 'hidden bot'),
        ('tr_other_thread', 'evt_other_thread', 'bot_1', 'thread_2', 'hidden thread'),
    )
    for transcript_id, event_id, bot_id, thread_id, content in seeded_rows:
        await transcript_store.append_transcript(
            transcript_id=transcript_id,
            event_id=event_id,
            conversation_id='conv_1',
            role='user',
            bot_id=bot_id,
            thread_id=thread_id,
            content=content,
        )
    actions = _handler(db_engine, session_registry).actions
    history_action = actions[PluginToRuntimeAction.HISTORY_PAGE.value]

    request = {
        'run_id': 'run_1',
        'caller_plugin_identity': 'test/runner',
    }
    response = await history_action(request)

    assert response.code == 0
    returned_contents = [entry['content'] for entry in response.data['items']]
    assert returned_contents == ['visible']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_event_page_filters_run_scope_thread_and_bot(session_registry, db_engine):
    """EVENT_PAGE returns only the event matching the session's bot and thread."""
    await _register_session(
        session_registry,
        bot_id='bot_1',
        thread_id='thread_1',
        available_apis={'event_page': True},
    )
    event_store = EventLogStore(db_engine)
    # (event_id, bot_id, thread_id) — only the first row matches both the bot and
    # the thread the session was registered with.
    seeded_rows = (
        ('evt_visible', 'bot_1', 'thread_1'),
        ('evt_other_bot', 'bot_2', 'thread_1'),
        ('evt_other_thread', 'bot_1', 'thread_2'),
    )
    for event_id, bot_id, thread_id in seeded_rows:
        await event_store.append_event(
            event_id=event_id,
            event_type='message.received',
            source='platform',
            bot_id=bot_id,
            conversation_id='conv_1',
            thread_id=thread_id,
        )
    actions = _handler(db_engine, session_registry).actions
    page_action = actions[PluginToRuntimeAction.EVENT_PAGE.value]

    request = {
        'run_id': 'run_1',
        'caller_plugin_identity': 'test/runner',
    }
    response = await page_action(request)

    assert response.code == 0
    returned_ids = [entry['event_id'] for entry in response.data['items']]
    assert returned_ids == ['evt_visible']
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Tests for agent runner ID parsing and formatting."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.id import (
|
||||
parse_runner_id,
|
||||
format_runner_id,
|
||||
RunnerIdParts,
|
||||
is_plugin_runner_id,
|
||||
)
|
||||
|
||||
|
||||
class TestRunnerIdParsing:
    """Behaviour of ``parse_runner_id`` for valid and malformed IDs."""

    def test_parse_plugin_runner_id(self):
        """A well-formed plugin ID splits into source, author, plugin and runner."""
        parsed = parse_runner_id('plugin:langbot/local-agent/default')

        assert parsed.source == 'plugin'
        assert (parsed.plugin_author, parsed.plugin_name, parsed.runner_name) == (
            'langbot',
            'local-agent',
            'default',
        )

    def test_parse_plugin_runner_id_complex_names(self):
        """Hyphenated plugin and runner names are preserved as-is."""
        parsed = parse_runner_id('plugin:alice/helpdesk-agent/ticket-handler')

        assert parsed.source == 'plugin'
        assert (parsed.plugin_author, parsed.plugin_name, parsed.runner_name) == (
            'alice',
            'helpdesk-agent',
            'ticket-handler',
        )

    def test_parse_invalid_plugin_runner_id_missing_parts(self):
        """An ID without the runner segment is rejected."""
        with pytest.raises(ValueError) as caught:
            parse_runner_id('plugin:langbot/local-agent')

        message = str(caught.value)
        assert 'Invalid plugin runner ID format' in message

    def test_parse_invalid_plugin_runner_id_empty_parts(self):
        """An ID whose author and plugin segments are empty is rejected."""
        with pytest.raises(ValueError) as caught:
            parse_runner_id('plugin://default')

        message = str(caught.value)
        assert 'non-empty' in message

    def test_parse_invalid_runner_id_not_plugin(self):
        """An ID without the ``plugin:`` prefix is rejected."""
        with pytest.raises(ValueError) as caught:
            parse_runner_id('local-agent')

        message = str(caught.value)
        assert 'Invalid runner ID format' in message

    def test_parse_invalid_runner_id_empty_string(self):
        """The empty string is rejected."""
        with pytest.raises(ValueError):
            parse_runner_id('')
|
||||
|
||||
|
||||
class TestRunnerIdFormatting:
    """Behaviour of ``format_runner_id``."""

    def test_format_plugin_runner_id(self):
        """Plugin parts are joined into ``plugin:<author>/<plugin>/<runner>``."""
        id_parts = {
            'source': 'plugin',
            'plugin_author': 'langbot',
            'plugin_name': 'local-agent',
            'runner_name': 'default',
        }

        assert format_runner_id(**id_parts) == 'plugin:langbot/local-agent/default'

    def test_format_invalid_source(self):
        """A source of 'builtin' is rejected with a ValueError."""
        id_parts = {
            'source': 'builtin',
            'plugin_author': 'langbot',
            'plugin_name': 'local-agent',
            'runner_name': 'default',
        }

        with pytest.raises(ValueError) as caught:
            format_runner_id(**id_parts)

        message = str(caught.value)
        assert 'Invalid runner source' in message
|
||||
|
||||
|
||||
class TestRunnerIdParts:
    """Behaviour of the ``RunnerIdParts`` value object."""

    def test_get_plugin_id(self):
        """``to_plugin_id`` yields ``<author>/<plugin>``."""
        field_values = {
            'source': 'plugin',
            'plugin_author': 'langbot',
            'plugin_name': 'local-agent',
            'runner_name': 'default',
        }
        id_parts = RunnerIdParts(**field_values)

        assert id_parts.to_plugin_id() == 'langbot/local-agent'

    def test_frozen_dataclass(self):
        """Assigning to a field after construction must raise."""
        field_values = {
            'source': 'plugin',
            'plugin_author': 'langbot',
            'plugin_name': 'local-agent',
            'runner_name': 'default',
        }
        id_parts = RunnerIdParts(**field_values)

        with pytest.raises(Exception):  # FrozenInstanceError
            id_parts.plugin_author = 'other'
|
||||
|
||||
|
||||
class TestIsPluginRunnerId:
    """Behaviour of ``is_plugin_runner_id``."""

    def test_is_plugin_runner_id_true(self):
        """A full plugin runner ID is recognised."""
        verdict = is_plugin_runner_id('plugin:langbot/local-agent/default')

        assert verdict is True

    def test_is_plugin_runner_id_false(self):
        """Bare names, other prefixes and the empty string are not plugin IDs."""
        for candidate in ('local-agent', 'builtin:local-agent', ''):
            assert is_plugin_runner_id(candidate) is False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,272 @@
|
||||
"""Tests for agent runner registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.registry import AgentRunnerRegistry
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.errors import RunnerNotFoundError, RunnerNotAuthorizedError
|
||||
|
||||
|
||||
class FakeApplication:
    """Minimal application stand-in exposing ``logger`` and ``plugin_connector``.

    The plugin connector reports four runner entries: two with complete
    manifests and two deliberately malformed ones (wrong ``kind``, missing name).
    """

    def __init__(self):
        class FakeLogger:
            """No-op logger.

            Each method accepts the same call shape as ``logging.Logger``
            (``msg, *args, **kwargs``) so code under test that logs with
            %-style arguments or ``exc_info=True`` does not fail with TypeError.
            """

            def info(self, msg, *args, **kwargs):
                pass

            def debug(self, msg, *args, **kwargs):
                pass

            def warning(self, msg, *args, **kwargs):
                pass

            def error(self, msg, *args, **kwargs):
                pass

        self.logger = FakeLogger()

        class FakePluginConnector:
            """Plugin runtime double returning a fixed list of runner entries."""

            # Tests flip this on the instance to simulate a disabled plugin system.
            is_enable_plugin = True

            async def list_agent_runners(self, bound_plugins=None):
                # ``bound_plugins`` is accepted but ignored: the same four
                # entries are returned (as a fresh list) on every call.
                return [
                    {
                        'plugin_author': 'langbot',
                        'plugin_name': 'local-agent',
                        'runner_name': 'default',
                        'manifest': {
                            'id': 'plugin:langbot/local-agent/default',
                            'name': 'default',
                            'label': {'en_US': 'Local Agent'},
                            'capabilities': {'streaming': True},
                            'permissions': {},
                            'config_schema': [],
                        },
                    },
                    {
                        'plugin_author': 'alice',
                        'plugin_name': 'my-agent',
                        'runner_name': 'custom',
                        'manifest': {
                            'id': 'plugin:alice/my-agent/custom',
                            'name': 'custom',
                            'label': {'en_US': 'Custom Agent'},
                            'capabilities': {},
                            'permissions': {},
                            'config_schema': [{'name': 'param1', 'type': 'string'}],
                        },
                    },
                    # Invalid runner - wrong kind
                    {
                        'plugin_author': 'bad',
                        'plugin_name': 'wrong-kind',
                        'runner_name': 'default',
                        'manifest': {
                            'kind': 'Tool',  # Wrong kind
                            'metadata': {},
                            'spec': {},
                        },
                    },
                    # Invalid runner - missing name
                    {
                        'plugin_author': 'bad',
                        'plugin_name': 'missing-name',
                        'runner_name': 'default',
                        'manifest': {
                            'kind': 'AgentRunner',
                            'metadata': {},  # No name
                            'spec': {},
                        },
                    },
                ]

        self.plugin_connector = FakePluginConnector()
|
||||
|
||||
|
||||
class TestRegistryDiscovery:
    """Tests for runner discovery."""

    @pytest.mark.asyncio
    async def test_discover_valid_runners(self):
        """Discover valid runners from plugin runtime."""
        ap = FakeApplication()
        registry = AgentRunnerRegistry(ap)

        runners = await registry.list_runners(use_cache=False)

        # Should find 2 valid runners (langbot/local-agent and alice/my-agent)
        assert len(runners) == 2

        ids = [r.id for r in runners]
        assert 'plugin:langbot/local-agent/default' in ids
        assert 'plugin:alice/my-agent/custom' in ids

    @pytest.mark.asyncio
    async def test_discover_caches_results(self):
        """Discovery should cache results.

        The second ``list_runners(use_cache=True)`` call must not go back to
        the plugin runtime, so the connector is wrapped with a call counter.
        """
        ap = FakeApplication()

        # Wrap the connector before the registry is built so every fetch is counted.
        real_fetch = ap.plugin_connector.list_agent_runners
        fetch_count = 0

        async def counting_fetch(*args, **kwargs):
            nonlocal fetch_count
            fetch_count += 1
            return await real_fetch(*args, **kwargs)

        ap.plugin_connector.list_agent_runners = counting_fetch
        registry = AgentRunnerRegistry(ap)

        # First discovery
        runners1 = await registry.list_runners(use_cache=True)
        fetches_after_first = fetch_count

        # Second call should use cache
        runners2 = await registry.list_runners(use_cache=True)

        assert registry._cache is not None
        # No additional round-trip to the plugin runtime for the cached call.
        assert fetch_count == fetches_after_first
        assert len(runners1) == len(runners2)
        assert sorted(r.id for r in runners1) == sorted(r.id for r in runners2)

    @pytest.mark.asyncio
    async def test_discover_handles_plugin_disabled(self):
        """Discovery returns empty when plugin system disabled."""
        ap = FakeApplication()
        ap.plugin_connector.is_enable_plugin = False
        registry = AgentRunnerRegistry(ap)

        runners = await registry.list_runners(use_cache=False)

        assert runners == []

    @pytest.mark.asyncio
    async def test_cache_not_polluted_by_bound_plugins(self):
        """Cache should contain ALL runners, not filtered by bound_plugins.

        Regression test: get(bound_plugins=["a/b"]) should not pollute cache,
        so subsequent list_runners(bound_plugins=None) should return all runners.
        """
        ap = FakeApplication()
        registry = AgentRunnerRegistry(ap)

        # First: get with bound_plugins filter (should not pollute cache)
        descriptor = await registry.get(
            'plugin:langbot/local-agent/default',
            bound_plugins=['langbot/local-agent'],
        )
        assert descriptor.id == 'plugin:langbot/local-agent/default'

        # Cache should contain ALL runners (both langbot and alice)
        assert registry._cache is not None
        assert len(registry._cache) == 2  # Both runners in cache
        assert 'plugin:langbot/local-agent/default' in registry._cache
        assert 'plugin:alice/my-agent/custom' in registry._cache

        # Second: list_runners without filter should return ALL runners
        all_runners = await registry.list_runners(bound_plugins=None, use_cache=True)
        assert len(all_runners) == 2  # Both runners returned

        # Third: list_runners with different filter should work correctly
        alice_runners = await registry.list_runners(bound_plugins=['alice/my-agent'], use_cache=True)
        assert len(alice_runners) == 1
        assert alice_runners[0].id == 'plugin:alice/my-agent/custom'
|
||||
|
||||
|
||||
class TestRegistryGet:
    """Lookup of a single runner through ``AgentRunnerRegistry.get``."""

    @pytest.mark.asyncio
    async def test_get_existing_runner(self):
        """A discovered runner is returned with its identifying fields populated."""
        registry = AgentRunnerRegistry(FakeApplication())

        found = await registry.get('plugin:langbot/local-agent/default')

        assert found.id == 'plugin:langbot/local-agent/default'
        assert (found.plugin_author, found.plugin_name, found.runner_name) == (
            'langbot',
            'local-agent',
            'default',
        )

    @pytest.mark.asyncio
    async def test_get_nonexistent_runner(self):
        """An unknown ID raises RunnerNotFoundError carrying that ID."""
        registry = AgentRunnerRegistry(FakeApplication())
        missing_id = 'plugin:notexist/unknown/default'

        with pytest.raises(RunnerNotFoundError) as caught:
            await registry.get(missing_id)

        assert caught.value.runner_id == 'plugin:notexist/unknown/default'

    @pytest.mark.asyncio
    async def test_get_runner_with_bound_plugins_filter(self):
        """``bound_plugins`` admits listed plugins and rejects the rest."""
        registry = AgentRunnerRegistry(FakeApplication())
        allowed_plugins = ['langbot/local-agent']

        # Authorized - langbot plugin in bound list
        permitted = await registry.get(
            'plugin:langbot/local-agent/default',
            bound_plugins=allowed_plugins,
        )
        assert permitted is not None

        # Not authorized - plugin not in bound list
        with pytest.raises(RunnerNotAuthorizedError):
            await registry.get(
                'plugin:alice/my-agent/custom',
                bound_plugins=['langbot/local-agent'],
            )
|
||||
|
||||
|
||||
class TestRegistryMetadataForPipeline:
    """Behaviour of ``get_runner_metadata_for_pipeline``."""

    @pytest.mark.asyncio
    async def test_get_metadata_options_and_stages(self):
        """Both runners become options; only the one with a config schema gets a stage."""
        registry = AgentRunnerRegistry(FakeApplication())

        options, stages = await registry.get_runner_metadata_for_pipeline()

        # One option per valid runner.
        assert len(options) == 2
        option_names = [option['name'] for option in options]
        for expected_name in (
            'plugin:langbot/local-agent/default',
            'plugin:alice/my-agent/custom',
        ):
            assert expected_name in option_names

        # Config comes from the typed manifest.
        assert len(stages) == 1
        only_stage = stages[0]
        assert only_stage['name'] == 'plugin:alice/my-agent/custom'
        first_field = only_stage['config'][0]
        assert first_field['name'] == 'param1'
        assert first_field['type'] == 'string'
        assert first_field['id'] == 'plugin:alice/my-agent/custom.param1'
|
||||
|
||||
|
||||
class TestDescriptorValidation:
    """Construction and helper methods of ``AgentRunnerDescriptor``."""

    def test_validate_runner_descriptor(self):
        """A descriptor built from consistent parts exposes its ID and plugin ID."""
        base_fields = {
            'id': 'plugin:test/my-runner/default',
            'source': 'plugin',
            'label': {'en_US': 'Test Runner'},
            'plugin_author': 'test',
            'plugin_name': 'my-runner',
            'runner_name': 'default',
        }
        built = AgentRunnerDescriptor(**base_fields)

        assert built.id == 'plugin:test/my-runner/default'
        assert built.get_plugin_id() == 'test/my-runner'
        # The model declares no ``protocol_version`` field.
        assert 'protocol_version' not in AgentRunnerDescriptor.model_fields

    def test_descriptor_capabilities(self):
        """Capability helpers reflect the flags given; an unlisted flag reads False."""
        base_fields = {
            'id': 'plugin:test/my-runner/default',
            'source': 'plugin',
            'label': {'en_US': 'Test Runner'},
            'plugin_author': 'test',
            'plugin_name': 'my-runner',
            'runner_name': 'default',
        }
        built = AgentRunnerDescriptor(
            capabilities={'streaming': True, 'tool_calling': False},
            **base_fields,
        )

        assert built.supports_streaming() is True
        assert built.supports_tool_calling() is False
        assert built.supports_knowledge_retrieval() is False
|
||||
@@ -0,0 +1,400 @@
|
||||
"""Tests for AgentResourceBuilder."""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.binding_resolver import AgentBindingResolver
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
|
||||
|
||||
|
||||
# Runner ID shared by every descriptor and query built in this module.
RUNNER_ID = 'plugin:test/runner/default'

# Default manifest permissions used by ``make_descriptor`` when a test does not
# supply its own.
FULL_PERMISSIONS = dict(
    models=['invoke', 'stream', 'rerank'],
    tools=['detail', 'call'],
    knowledge_bases=['list', 'retrieve'],
    history=['page', 'search'],
    events=['get', 'page'],
    storage=['plugin', 'workspace'],
)
|
||||
|
||||
|
||||
def make_descriptor(
    *,
    config_schema: list[dict] | None = None,
    capabilities: dict | None = None,
    permissions: dict | None = None,
) -> AgentRunnerDescriptor:
    """Build the test runner descriptor.

    ``permissions`` falls back to ``FULL_PERMISSIONS`` only when it is ``None``;
    an explicit empty dict is passed through unchanged. Falsy ``capabilities``
    and ``config_schema`` become ``{}`` and ``[]``.
    """
    effective_permissions = FULL_PERMISSIONS if permissions is None else permissions
    effective_capabilities = capabilities or {}
    effective_schema = config_schema or []
    return AgentRunnerDescriptor(
        id=RUNNER_ID,
        source='plugin',
        label={'en_US': 'Test Runner'},
        plugin_author='test',
        plugin_name='runner',
        runner_name='default',
        capabilities=effective_capabilities,
        permissions=effective_permissions,
        config_schema=effective_schema,
    )
|
||||
|
||||
|
||||
def make_model(model_type='llm', provider='test-provider'):
    """Return a stub model exposing ``model_entity.model_type`` and ``provider_entity.name``."""
    entity = SimpleNamespace(model_type=model_type)
    provider_info = SimpleNamespace(name=provider)
    return SimpleNamespace(model_entity=entity, provider_entity=provider_info)
|
||||
|
||||
|
||||
def make_query(
    runner_config: dict,
    *,
    variables: dict | None = None,
    use_llm_model_uuid=None,
    use_funcs: list | None = None,
):
    """Return a stub query whose pipeline selects ``RUNNER_ID`` with ``runner_config``.

    Falsy ``variables`` and ``use_funcs`` become ``{}`` and ``[]``.
    """
    ai_section = {
        'runner': {'id': RUNNER_ID},
        'runner_config': {RUNNER_ID: runner_config},
    }
    stub_query = SimpleNamespace(
        query_id=1,
        bot_uuid='bot_001',
        launcher_type='person',
        launcher_id='launcher_001',
        sender_id='sender_001',
        message_event=None,
        message_chain=None,
        user_message=None,
        session=None,
        pipeline_config={'ai': ai_section},
        variables=variables or {},
        use_llm_model_uuid=use_llm_model_uuid,
        use_funcs=use_funcs or [],
        pipeline_uuid='pipeline_001',
    )
    return stub_query
|
||||
|
||||
|
||||
async def build_resources(app, query, descriptor):
    """Convert ``query`` to an event, resolve its binding, and build runner resources."""
    event = QueryEntryAdapter.query_to_event(query)
    candidate_configs = [QueryEntryAdapter.config_to_agent_config(query, descriptor.id)]
    binding = AgentBindingResolver().resolve_one(event, candidate_configs)
    builder = AgentResourceBuilder(app)
    resources = await builder.build_resources_from_binding(
        event=event,
        binding=binding,
        descriptor=descriptor,
    )
    return resources
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Mock application whose knowledge-base lookup returns ``None`` by default."""
    rag_manager = Mock()
    rag_manager.get_knowledge_base_by_uuid = AsyncMock(return_value=None)
    return Mock(logger=Mock(), model_mgr=Mock(), rag_mgr=rag_manager)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app):
    """Models chosen through config selectors are listed once each, LLMs before rerank."""
    llm_by_uuid = {
        'primary': make_model(),
        'fallback': make_model(),
        'aux': make_model(provider='aux-provider'),
    }
    rerank_by_uuid = {
        'rerank': make_model(model_type='rerank', provider='rerank-provider'),
    }
    # Unknown UUIDs resolve to None, as with the dict lookups they wrap.
    app.model_mgr.get_model_by_uuid = AsyncMock(
        side_effect=lambda model_uuid: llm_by_uuid.get(model_uuid)
    )
    app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
        side_effect=lambda model_uuid: rerank_by_uuid.get(model_uuid)
    )
    schema = [
        {'name': 'model', 'type': 'model-fallback-selector'},
        {'name': 'aux-model', 'type': 'llm-model-selector'},
        {'name': 'rerank-model', 'type': 'rerank-model-selector'},
    ]
    descriptor = make_descriptor(config_schema=schema)
    # 'primary' is repeated in the fallbacks to exercise de-duplication.
    query = make_query({
        'model': {'primary': 'primary', 'fallbacks': ['fallback', 'primary']},
        'aux-model': 'aux',
        'rerank-model': 'rerank',
    })

    resources = await build_resources(app, query, descriptor)

    expected_models = [
        {'model_id': model_id, 'model_type': 'llm', 'provider': provider, 'operations': ['invoke', 'stream']}
        for model_id, provider in (
            ('primary', 'test-provider'),
            ('fallback', 'test-provider'),
            ('aux', 'aux-provider'),
        )
    ]
    expected_models.append(
        {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']}
    )
    assert resources['models'] == expected_models
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_models_from_config_without_manifest_acl(app):
    """With empty manifest permissions, config-selected models yield no model grants."""
    app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
    app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=make_model(model_type='rerank'))
    schema = [
        {'name': 'model', 'type': 'model-fallback-selector'},
        {'name': 'rerank-model', 'type': 'rerank-model-selector'},
    ]
    descriptor = make_descriptor(config_schema=schema, permissions={})
    runner_config = {
        'model': {'primary': 'primary', 'fallbacks': ['fallback']},
        'rerank-model': 'rerank',
    }
    query = make_query(runner_config)

    resources = await build_resources(app, query, descriptor)

    assert resources['models'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
    """An LLM selector and a rerank selector each produce one model grant."""
    rerank_stub = make_model(model_type='rerank', provider='rerank-provider')
    app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
    app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=rerank_stub)
    schema = [
        {'name': 'model', 'type': 'llm-model-selector'},
        {'name': 'rerank-model', 'type': 'rerank-model-selector'},
    ]
    descriptor = make_descriptor(config_schema=schema)
    runner_config = {
        'model': 'llm',
        'rerank-model': 'rerank',
    }
    query = make_query(runner_config)

    resources = await build_resources(app, query, descriptor)

    llm_grant = {'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']}
    rerank_grant = {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']}
    assert resources['models'] == [llm_grant, rerank_grant]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_resources_accepts_dynamic_form_type_aliases(app):
    """The ``select-llm-model`` / ``select-knowledge-bases`` field types yield grants."""
    app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())

    async def fake_kb_lookup(kb_uuid):
        # Every UUID resolves to a stub KB named after it.
        kb_entity = SimpleNamespace(kb_type='default')
        return SimpleNamespace(
            uuid=kb_uuid,
            get_name=lambda: f'name-{kb_uuid}',
            knowledge_base_entity=kb_entity,
        )

    app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=fake_kb_lookup)
    schema = [
        {'name': 'model', 'type': 'select-llm-model'},
        {'name': 'knowledge-bases', 'type': 'select-knowledge-bases'},
    ]
    descriptor = make_descriptor(
        capabilities={'knowledge_retrieval': True},
        config_schema=schema,
    )
    runner_config = {
        'model': 'llm_alias',
        'knowledge-bases': ['kb_alias'],
    }
    query = make_query(runner_config)

    resources = await build_resources(app, query, descriptor)

    expected_model = {'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']}
    expected_kb = {'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']}
    assert resources['models'] == [expected_model]
    assert resources['knowledge_bases'] == [expected_kb]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_models_manifest_permission_narrows_binding(app):
    """When the manifest lists only 'rerank' for models, the LLM grant is dropped."""
    rerank_stub = make_model(model_type='rerank', provider='rerank-provider')
    app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
    app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=rerank_stub)
    rerank_only_permissions = {**FULL_PERMISSIONS, 'models': ['rerank']}
    schema = [
        {'name': 'model', 'type': 'llm-model-selector'},
        {'name': 'rerank-model', 'type': 'rerank-model-selector'},
    ]
    descriptor = make_descriptor(config_schema=schema, permissions=rerank_only_permissions)
    runner_config = {
        'model': 'llm',
        'rerank-model': 'rerank',
    }
    query = make_query(runner_config)

    resources = await build_resources(app, query, descriptor)

    rerank_grant = {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']}
    assert resources['models'] == [rerank_grant]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_models_deduplicates_query_and_config_models(app):
    """Models named by both the query and the runner config are listed only once."""
    app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
    app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=None)
    schema = [
        {'name': 'model', 'type': 'model-fallback-selector'},
    ]
    descriptor = make_descriptor(config_schema=schema)
    # The same primary/fallback pair is supplied via config and via query fields.
    runner_config = {'model': {'primary': 'primary', 'fallbacks': ['fallback']}}
    query = make_query(
        runner_config,
        variables={'_fallback_model_uuids': ['fallback']},
        use_llm_model_uuid='primary',
    )

    resources = await build_resources(app, query, descriptor)

    granted_ids = [grant['model_id'] for grant in resources['models']]
    assert granted_ids == ['primary', 'fallback']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_tools_authorizes_query_declared_tools(app):
    """Each entry in the query's ``use_funcs`` (dict or object) becomes a tool grant."""
    descriptor = make_descriptor(capabilities={'tool_calling': True})
    declared_funcs = [
        {'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
        SimpleNamespace(name='qa_mcp_echo'),
    ]
    query = make_query({}, use_funcs=declared_funcs)

    resources = await build_resources(app, query, descriptor)

    # Grants carry no tool_type or description, even for the dict that had one.
    expected_tools = [
        {
            'tool_name': tool_name,
            'tool_type': None,
            'description': None,
            'operations': ['detail', 'call'],
        }
        for tool_name in ('qa_plugin_echo', 'qa_mcp_echo')
    ]
    assert resources['tools'] == expected_tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_tools_manifest_permission_denies_binding_tools(app):
    """With an empty manifest 'tools' permission, query-declared tools are not granted."""
    no_tool_permissions = {**FULL_PERMISSIONS, 'tools': []}
    descriptor = make_descriptor(
        capabilities={'tool_calling': True},
        permissions=no_tool_permissions,
    )
    declared_funcs = [
        {'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
    ]
    query = make_query({}, use_funcs=declared_funcs)

    resources = await build_resources(app, query, descriptor)

    assert resources['tools'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
    """KBs from runner config and from the query variables are both granted, config first."""
    schema = [
        {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
    ]
    descriptor = make_descriptor(
        capabilities={'knowledge_retrieval': True},
        config_schema=schema,
    )
    query = make_query(
        {'knowledge-bases': ['kb_config']},
        variables={'_knowledge_base_uuids': ['kb_policy']},
    )

    async def fake_kb_lookup(kb_uuid):
        # Every UUID resolves to a stub KB named after it.
        kb_entity = SimpleNamespace(kb_type='default')
        return SimpleNamespace(
            uuid=kb_uuid,
            get_name=lambda: f'name-{kb_uuid}',
            knowledge_base_entity=kb_entity,
        )

    app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=fake_kb_lookup)

    resources = await build_resources(app, query, descriptor)

    config_grant = {'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default', 'operations': ['list', 'retrieve']}
    policy_grant = {'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default', 'operations': ['list', 'retrieve']}
    assert resources['knowledge_bases'] == [config_grant, policy_grant]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_knowledge_bases_manifest_permission_denies_binding_kbs(app):
    """With an empty manifest 'knowledge_bases' permission, no KB is granted."""
    no_kb_permissions = {**FULL_PERMISSIONS, 'knowledge_bases': []}
    schema = [
        {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
    ]
    descriptor = make_descriptor(
        capabilities={'knowledge_retrieval': True},
        permissions=no_kb_permissions,
        config_schema=schema,
    )
    query = make_query(
        {'knowledge-bases': ['kb_config']},
        variables={'_knowledge_base_uuids': ['kb_policy']},
    )

    resources = await build_resources(app, query, descriptor)

    assert resources['knowledge_bases'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_storage_intersects_manifest_and_binding_policy(app):
    """A manifest requesting only 'plugin' storage gets plugin but not workspace storage."""
    plugin_only_permissions = {**FULL_PERMISSIONS, 'storage': ['plugin']}
    descriptor = make_descriptor(permissions=plugin_only_permissions)
    query = make_query({})

    resources = await build_resources(app, query, descriptor)

    expected_storage = {
        'plugin_storage': True,
        'workspace_storage': False,
    }
    assert resources['storage'] == expected_storage
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Tests for agent runner result normalizer."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.result_normalizer import AgentResultNormalizer
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.errors import RunnerExecutionError, RunnerProtocolError
|
||||
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
|
||||
class FakeApplication:
    """Minimal stand-in for the host application used by the normalizer tests.

    Only the ``logger`` attribute is provided. Its methods mirror the call
    shape of ``logging.Logger`` (``msg, *args, **kwargs``) so that code under
    test may log with lazy %-style arguments or keywords such as
    ``exc_info=True`` without the fake itself raising ``TypeError``.
    """

    def __init__(self):
        class FakeLogger:
            """Logger double that records warning messages and ignores everything else."""

            def __init__(self):
                # Format strings passed to warning(), in call order.
                self.warnings = []

            def info(self, msg, *args, **kwargs):
                pass

            def debug(self, msg, *args, **kwargs):
                pass

            def warning(self, msg, *args, **kwargs):
                # Store only the message itself; extra args are accepted but not recorded.
                self.warnings.append(msg)

            def error(self, msg, *args, **kwargs):
                pass

        self.logger = FakeLogger()
|
||||
|
||||
|
||||
def make_descriptor():
    """Build the AgentRunnerDescriptor shared by the normalizer tests."""
    descriptor_fields = {
        'id': 'plugin:langbot/local-agent/default',
        'source': 'plugin',
        'label': {'en_US': 'Local Agent', 'zh_Hans': '内置 Agent'},
        'plugin_author': 'langbot',
        'plugin_name': 'local-agent',
        'runner_name': 'default',
        'capabilities': {'streaming': True},
    }
    return AgentRunnerDescriptor(**descriptor_fields)
|
||||
|
||||
|
||||
class TestNormalizeMessageDelta:
    """Checks how `message.delta` results are normalized."""

    @pytest.mark.asyncio
    async def test_normalize_message_delta_text(self):
        """A delta carrying a text chunk becomes a MessageChunk with the same role and content."""
        chunk = {'role': 'assistant', 'content': 'Hello'}
        event = {'type': 'message.delta', 'data': {'chunk': chunk}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is not None
        assert isinstance(normalized, provider_message.MessageChunk)
        assert (normalized.role, normalized.content) == ('assistant', 'Hello')

    @pytest.mark.asyncio
    async def test_normalize_message_delta_missing_chunk(self):
        """A delta whose data has no chunk normalizes to None."""
        event = {'type': 'message.delta', 'data': {}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None
|
||||
|
||||
|
||||
class TestNormalizeMessageCompleted:
    """Checks how `message.completed` results are normalized."""

    @pytest.mark.asyncio
    async def test_normalize_message_completed(self):
        """A completed event carrying a message becomes a Message with the same role and content."""
        message = {'role': 'assistant', 'content': 'Complete response'}
        event = {'type': 'message.completed', 'data': {'message': message}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is not None
        assert isinstance(normalized, provider_message.Message)
        assert (normalized.role, normalized.content) == ('assistant', 'Complete response')

    @pytest.mark.asyncio
    async def test_normalize_message_completed_missing_message(self):
        """A completed event whose data has no message normalizes to None."""
        event = {'type': 'message.completed', 'data': {}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None
|
||||
|
||||
|
||||
class TestNormalizeRunCompleted:
    """Checks how `run.completed` results are normalized."""

    @pytest.mark.asyncio
    async def test_normalize_run_completed_with_message(self):
        """A run.completed event that carries a final message yields a Message."""
        final_message = {'role': 'assistant', 'content': 'Final response'}
        event = {
            'type': 'run.completed',
            'data': {'message': final_message, 'finish_reason': 'stop'},
        }
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is not None
        assert isinstance(normalized, provider_message.Message)

    @pytest.mark.asyncio
    async def test_normalize_run_completed_without_message(self):
        """A run.completed event with only a finish reason normalizes to None."""
        event = {'type': 'run.completed', 'data': {'finish_reason': 'stop'}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None
|
||||
|
||||
|
||||
class TestNormalizeRunFailed:
    """Checks how `run.failed` results are normalized."""

    @pytest.mark.asyncio
    async def test_normalize_run_failed(self):
        """A run.failed event raises RunnerExecutionError carrying runner id and retry flag."""
        failure = {
            'error': 'Upstream timeout',
            'code': 'upstream.timeout',
            'retryable': True,
        }
        event = {'type': 'run.failed', 'data': failure}
        sut = AgentResultNormalizer(FakeApplication())
        descriptor = make_descriptor()

        with pytest.raises(RunnerExecutionError) as raised:
            await sut.normalize(event, descriptor)

        error = raised.value
        assert error.runner_id == 'plugin:langbot/local-agent/default'
        assert error.retryable is True
        assert 'timeout' in str(error)
|
||||
|
||||
|
||||
class TestNormalizeNonMessageResults:
    """Checks result types that do not produce a message when normalized."""

    @pytest.mark.asyncio
    async def test_normalize_tool_call_started(self):
        """tool.call.started normalizes to None."""
        event = {
            'type': 'tool.call.started',
            'data': {'tool_call_id': 'call_1', 'tool_name': 'weather'},
        }
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None

    @pytest.mark.asyncio
    async def test_normalize_tool_call_completed(self):
        """tool.call.completed normalizes to None."""
        tool_data = {
            'tool_call_id': 'call_1',
            'tool_name': 'weather',
            'result': {'temp': 20},
        }
        event = {'type': 'tool.call.completed', 'data': tool_data}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None

    @pytest.mark.asyncio
    async def test_normalize_state_updated(self):
        """state.updated with a conversation scope normalizes to None."""
        state_data = {
            'scope': 'conversation',
            'key': 'external_conversation_id',
            'value': 'abc123',
        }
        event = {'type': 'state.updated', 'data': state_data}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None

    @pytest.mark.asyncio
    async def test_normalize_action_requested(self):
        """action.requested normalizes to None (the original note marks it as EBA reserved)."""
        event = {
            'type': 'action.requested',
            'data': {'action': 'platform.message.edit', 'payload': {}},
        }
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None

    @pytest.mark.asyncio
    async def test_invalid_state_updated_payload_is_dropped(self):
        """state.updated with scope 'invalid' normalizes to None and logs a warning."""
        app = FakeApplication()
        sut = AgentResultNormalizer(app)
        descriptor = make_descriptor()
        event = {
            'type': 'state.updated',
            'data': {'scope': 'invalid', 'key': 'k', 'value': 'v'},
        }

        normalized = await sut.normalize(event, descriptor)

        assert normalized is None
        # The fake logger collects warning messages; at least one must have been emitted.
        assert app.logger.warnings
|
||||
|
||||
class TestNormalizeInvalidResults:
    """Checks how malformed or unrecognized results are handled."""

    @pytest.mark.asyncio
    async def test_normalize_missing_type(self):
        """A result without a `type` key raises RunnerProtocolError."""
        event = {'data': {}}
        sut = AgentResultNormalizer(FakeApplication())
        descriptor = make_descriptor()

        with pytest.raises(RunnerProtocolError) as raised:
            await sut.normalize(event, descriptor)

        assert 'Missing result type' in str(raised.value)

    @pytest.mark.asyncio
    async def test_normalize_unknown_type(self):
        """A result with an unrecognized type normalizes to None."""
        event = {'type': 'unknown_type', 'data': {}}
        sut = AgentResultNormalizer(FakeApplication())

        normalized = await sut.normalize(event, make_descriptor())

        assert normalized is None

    @pytest.mark.asyncio
    async def test_normalize_legacy_type_returns_none(self):
        """The legacy `chunk`, `text` and `finish` types each normalize to None."""
        sut = AgentResultNormalizer(FakeApplication())
        descriptor = make_descriptor()

        # Same order as before: chunk, then text, then finish.
        legacy_events = (
            {
                'type': 'chunk',
                'data': {
                    'message_chunk': {'role': 'assistant', 'content': 'Legacy chunk'},
                },
            },
            {
                'type': 'text',
                'data': {'content': 'Legacy text'},
            },
            {
                'type': 'finish',
                'data': {
                    'message': {'role': 'assistant', 'content': 'Legacy finish'},
                },
            },
        )
        for legacy_event in legacy_events:
            normalized = await sut.normalize(legacy_event, descriptor)
            assert normalized is None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,430 @@
|
||||
"""Tests for RunLedgerStore host primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore
|
||||
from langbot.pkg.entity.persistence.agent_run import AgentRun
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
|
||||
# Shorthand for the UTC tzinfo; the tests below pass it to datetime.now() to get aware timestamps.
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
@pytest.fixture
async def db_engine(tmp_path):
    """Yield an async SQLite engine with all `Base` tables created, then dispose of it."""
    db_path = tmp_path / 'run_ledger_store.db'
    database_url = f'sqlite+aiosqlite:///{db_path}'
    engine = create_async_engine(database_url, echo=False)

    # Create the schema inside a single transaction before handing the engine out.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    # Teardown: release the engine's connections once the test is finished.
    await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
def store(db_engine):
    """Provide a RunLedgerStore bound to the per-test database engine."""
    ledger_store = RunLedgerStore(db_engine)
    return ledger_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_queued_run_claim_renew_release(store):
    """Walk one queued run through create, claim, renew and release."""
    run_id = 'run-queued'
    queue = 'default'

    created = await store.create_run(
        run_id=run_id,
        event_id='evt-1',
        binding_id='binding-1',
        runner_id='runner-a',
        status='queued',
        queue_name=queue,
        priority=10,
        requested_runtime_id='runtime-a',
    )

    assert created['started_at'] is None
    expected_created_fields = {
        'status': 'queued',
        'queue_name': 'default',
        'priority': 10,
        'requested_runtime_id': 'runtime-a',
    }
    for field, expected in expected_created_fields.items():
        assert created[field] == expected

    # The run requested runtime-a, so a claim from runtime-b returns nothing.
    foreign_claim = await store.claim_next_run(runtime_id='runtime-b', queue_name=queue)
    assert foreign_claim is None

    claim = await store.claim_next_run(runtime_id='runtime-a', queue_name=queue, lease_seconds=30)

    assert claim is not None
    assert claim['run_id'] == 'run-queued'
    assert claim['status'] == 'claimed'
    assert claim['claimed_by_runtime_id'] == 'runtime-a'
    assert claim['claim_token']
    assert claim['dispatch_attempts'] == 1
    assert claim['claim_lease_expires_at'] is not None
    assert claim['last_claimed_at'] is not None

    token = claim['claim_token']

    # Renewal is refused for a token that does not match the claim.
    rejected_renewal = await store.renew_claim(run_id=run_id, claim_token='wrong-token')
    assert rejected_renewal is None

    renewal = await store.renew_claim(run_id=run_id, claim_token=token, lease_seconds=90)

    assert renewal is not None
    assert 'claim_token' not in renewal
    assert renewal['claim_lease_expires_at'] >= claim['claim_lease_expires_at']

    release = await store.release_claim(
        run_id=run_id,
        claim_token=token,
        status='queued',
        status_reason='runtime released capacity',
    )

    assert release is not None
    assert release['status'] == 'queued'
    assert release['status_reason'] == 'runtime released capacity'
    assert release['claimed_by_runtime_id'] is None
    assert 'claim_token' not in release
    assert release['claim_lease_expires_at'] is None
    # Releasing the claim leaves the attempt counter at its claimed value.
    assert release['dispatch_attempts'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_next_run_applies_scope_filters(store):
    """Only the run matching every scope filter is claimed, despite its smaller priority value."""
    shared_fields = {
        'binding_id': 'binding-1',
        'bot_id': 'bot-a',
        'workspace_id': 'workspace-a',
        'status': 'queued',
        'queue_name': 'default',
    }
    # (run_id, event_id, runner_id, conversation_id, priority), created in this order.
    queued_runs = (
        ('run-other-runner', 'evt-other-runner', 'runner-b', 'conv-a', 30),
        ('run-other-conversation', 'evt-other-conversation', 'runner-a', 'conv-b', 20),
        ('run-allowed', 'evt-allowed', 'runner-a', 'conv-a', 10),
    )
    for run_id, event_id, runner_id, conversation_id, priority in queued_runs:
        await store.create_run(
            run_id=run_id,
            event_id=event_id,
            runner_id=runner_id,
            conversation_id=conversation_id,
            priority=priority,
            **shared_fields,
        )

    scope_filters = {
        'runner_ids': ['runner-a'],
        'conversation_id': 'conv-a',
        'bot_id': 'bot-a',
        'workspace_id': 'workspace-a',
        'thread_id': None,
        'strict_thread': True,
    }
    claim = await store.claim_next_run(
        runtime_id='runtime-a',
        queue_name='default',
        **scope_filters,
    )

    assert claim is not None
    assert claim['run_id'] == 'run-allowed'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_expired_claim_can_be_reclaimed(store, db_engine):
    """A run whose claim lease is in the past can be claimed again by another runtime."""
    await store.create_run(
        run_id='run-expired',
        event_id='evt-2',
        binding_id='binding-1',
        runner_id='runner-a',
        status='queued',
        queue_name='default',
    )
    original_claim = await store.claim_next_run(
        runtime_id='runtime-a',
        queue_name='default',
        lease_seconds=60,
    )
    assert original_claim is not None

    # Backdate the lease directly in the database so the claim is already expired.
    open_session = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with open_session() as session:
        one_second_ago = datetime.datetime.now(UTC) - datetime.timedelta(seconds=1)
        backdate_lease = (
            sqlalchemy.update(AgentRun)
            .where(AgentRun.run_id == 'run-expired')
            .values(claim_lease_expires_at=one_second_ago)
        )
        await session.execute(backdate_lease)
        await session.commit()

    second_claim = await store.claim_next_run(
        runtime_id='runtime-b',
        queue_name='default',
        lease_seconds=60,
    )

    assert second_claim is not None
    assert second_claim['run_id'] == 'run-expired'
    assert second_claim['claimed_by_runtime_id'] == 'runtime-b'
    # The second claim gets a different token and bumps the attempt counter.
    assert second_claim['claim_token'] != original_claim['claim_token']
    assert second_claim['dispatch_attempts'] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_release_expired_claims_requeues_runs(store, db_engine):
    """release_expired_claims requeues the run with a lapsed lease and leaves the other claimed."""
    for run_id, event_id in (('run-expired-release', 'evt-3'), ('run-active-claim', 'evt-4')):
        await store.create_run(
            run_id=run_id,
            event_id=event_id,
            binding_id='binding-1',
            runner_id='runner-a',
            status='queued',
            queue_name='default',
        )

    expired_claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
    active_claim = await store.claim_next_run(runtime_id='runtime-b', queue_name='default', lease_seconds=60)
    assert expired_claim is not None
    assert active_claim is not None

    # Backdate only the first run's lease directly in the database.
    open_session = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with open_session() as session:
        one_second_ago = datetime.datetime.now(UTC) - datetime.timedelta(seconds=1)
        backdate_lease = (
            sqlalchemy.update(AgentRun)
            .where(AgentRun.run_id == 'run-expired-release')
            .values(claim_lease_expires_at=one_second_ago)
        )
        await session.execute(backdate_lease)
        await session.commit()

    released = await store.release_expired_claims()

    released_ids = [run['run_id'] for run in released]
    assert released_ids == ['run-expired-release']
    requeued = released[0]
    assert requeued['status'] == 'queued'
    assert requeued['status_reason'] == 'claim lease expired'
    assert requeued['claimed_by_runtime_id'] is None
    assert 'claim_token' not in requeued
    assert requeued['claim_lease_expires_at'] is None

    still_claimed = await store.get_run('run-active-claim')
    assert still_claimed is not None
    assert still_claimed['status'] == 'claimed'
    assert still_claimed['claimed_by_runtime_id'] == active_claim['claimed_by_runtime_id']
    assert 'claim_token' not in still_claimed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_expired_claim_cannot_renew_or_release(store, db_engine):
    """Once the lease is in the past, renew_claim and release_claim both return None."""
    await store.create_run(
        run_id='run-stale-claim',
        event_id='evt-stale',
        binding_id='binding-1',
        runner_id='runner-a',
        status='queued',
        queue_name='default',
    )
    claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
    assert claim is not None
    token = claim['claim_token']

    # Backdate the lease directly in the database so the held token is stale.
    open_session = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with open_session() as session:
        one_second_ago = datetime.datetime.now(UTC) - datetime.timedelta(seconds=1)
        backdate_lease = (
            sqlalchemy.update(AgentRun)
            .where(AgentRun.run_id == 'run-stale-claim')
            .values(claim_lease_expires_at=one_second_ago)
        )
        await session.execute(backdate_lease)
        await session.commit()

    renewal = await store.renew_claim(run_id='run-stale-claim', claim_token=token, runtime_id='runtime-a')
    assert renewal is None

    release = await store.release_claim(run_id='run-stale-claim', claim_token=token, runtime_id='runtime-a')
    assert release is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_status_validation_and_terminal_transition_rules(store):
    """Unknown statuses are rejected everywhere and a terminal run cannot change status."""
    unknown_status = 'Unknown run status'

    # create_run rejects an unknown status.
    with pytest.raises(ValueError, match=unknown_status):
        await store.create_run(
            run_id='run-invalid-create',
            event_id='evt-invalid',
            binding_id='binding-1',
            runner_id='runner-a',
            status='bogus',
        )

    # release_claim rejects an unknown target status, even with a valid claim.
    await store.create_run(
        run_id='run-invalid-release',
        event_id='evt-release',
        binding_id='binding-1',
        runner_id='runner-a',
        status='queued',
        queue_name='default',
    )
    claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default')
    assert claim is not None
    with pytest.raises(ValueError, match=unknown_status):
        await store.release_claim(
            run_id='run-invalid-release',
            claim_token=claim['claim_token'],
            runtime_id='runtime-a',
            status='bogus',
        )

    # finalize_run rejects an unknown status.
    await store.create_run(
        run_id='run-terminal',
        event_id='evt-terminal',
        binding_id='binding-1',
        runner_id='runner-a',
    )
    with pytest.raises(ValueError, match=unknown_status):
        await store.finalize_run(run_id='run-terminal', status='bogus')

    first_finalize = await store.finalize_run(
        run_id='run-terminal',
        status='completed',
        metadata={'attempt': 1},
    )
    assert first_finalize is not None
    assert first_finalize['status'] == 'completed'

    # Finalizing again with the same status succeeds and merges the metadata.
    second_finalize = await store.finalize_run(
        run_id='run-terminal',
        status='completed',
        metadata={'retry_observed': True},
    )
    assert second_finalize is not None
    assert second_finalize['metadata'] == {'attempt': 1, 'retry_observed': True}

    # Moving a completed run to a different status is refused.
    with pytest.raises(ValueError, match='Cannot transition terminal run'):
        await store.finalize_run(run_id='run-terminal', status='failed')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_append_audit_event_uses_next_sequence(store):
    """An audit event is numbered after the run's existing event; a missing run gives None."""
    run_id = 'run-audit'
    await store.create_run(
        run_id=run_id,
        event_id='evt-5',
        binding_id='binding-1',
        runner_id='runner-a',
    )
    # Seed the run with one event at sequence 1.
    await store.append_event(
        run_id=run_id,
        sequence=1,
        event_type='message.completed',
        data={'ok': True},
    )

    audit_event = await store.append_audit_event(
        run_id=run_id,
        event_type='admin.run_cancel',
        data={'action': 'run_cancel'},
        metadata={'permission': 'agent_run:admin'},
    )

    assert audit_event is not None
    expected_fields = {
        'sequence': 2,
        'type': 'admin.run_cancel',
        'source': 'host',
        'data': {'action': 'run_cancel'},
        'metadata': {'permission': 'agent_run:admin'},
    }
    for field, expected in expected_fields.items():
        assert audit_event[field] == expected

    missing_run_event = await store.append_audit_event(run_id='missing', event_type='admin.missing')
    assert missing_run_event is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_register_heartbeat_list_and_mark_stale(store):
    """Register a runtime, heartbeat it, list it, then mark it stale past its deadline."""
    runtime_id = 'runtime-a'

    registered = await store.register_runtime(
        runtime_id=runtime_id,
        display_name='Runtime A',
        endpoint='http://runtime-a',
        version='1.0.0',
        capabilities={'stream': True},
        labels={'region': 'test'},
        metadata={'slot_count': 2},
        heartbeat_deadline_seconds=30,
    )

    expected_registration = {
        'runtime_id': 'runtime-a',
        'status': 'online',
        'display_name': 'Runtime A',
        'capabilities': {'stream': True},
        'labels': {'region': 'test'},
        'metadata': {'slot_count': 2},
    }
    for field, expected in expected_registration.items():
        assert registered[field] == expected
    assert registered['last_heartbeat_at'] is not None
    assert registered['heartbeat_deadline_at'] is not None

    heartbeat = await store.heartbeat_runtime(
        runtime_id=runtime_id,
        metadata={'active_runs': 1},
        heartbeat_deadline_seconds=30,
    )

    assert heartbeat is not None
    # Heartbeat metadata is merged into what was stored at registration.
    assert heartbeat['metadata'] == {'slot_count': 2, 'active_runs': 1}

    online_runtimes, total_count = await store.list_runtimes(statuses=['online'])
    assert [runtime['runtime_id'] for runtime in online_runtimes] == ['runtime-a']
    assert total_count == 1

    # Evaluate staleness 31 seconds ahead, one second past the 30-second deadline.
    past_deadline = datetime.datetime.now(UTC) + datetime.timedelta(seconds=31)
    stale = await store.mark_stale_runtimes(now=past_deadline)

    assert [runtime['runtime_id'] for runtime in stale] == ['runtime-a']
    assert stale[0]['status'] == 'stale'
    reloaded = await store.get_runtime('runtime-a')
    assert reloaded['status'] == 'stale'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_stats_splits_active_and_claimed_runs(store):
    """With one running and one claimed run, stats report two active runs and one claimed."""
    await store.register_runtime(runtime_id='runtime-a')

    running_run = {
        'run_id': 'run-running',
        'event_id': 'evt-running',
        'binding_id': 'binding-1',
        'runner_id': 'runner-a',
        'status': 'running',
    }
    await store.create_run(**running_run)

    queued_run = {
        'run_id': 'run-claimed',
        'event_id': 'evt-claimed',
        'binding_id': 'binding-1',
        'runner_id': 'runner-a',
        'status': 'queued',
        'queue_name': 'default',
    }
    await store.create_run(**queued_run)

    # Claim the queued run so it is counted as claimed.
    claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default')
    assert claim is not None

    stats = await store.get_runtime_stats()

    assert stats['active_runs'] == 2
    assert stats['claimed_runs'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_stats_reports_zero_success_rate_for_failed_only_runner(store):
    """A runner whose only run failed reports one failed run and a success rate of 0.0."""
    # Captured before the run is created; the stats window is +/- 10 seconds around it.
    now = int(datetime.datetime.now(UTC).timestamp())
    window_start = now - 10
    window_end = now + 10

    await store.create_run(
        run_id='run-failed',
        event_id='evt-failed',
        binding_id='binding-1',
        runner_id='runner-a',
        status='failed',
    )

    stats = await store.get_runner_stats(start_time=window_start, end_time=window_end)

    first_runner = stats[0]
    assert first_runner['runner_id'] == 'runner-a'
    assert first_runner['failed_runs'] == 1
    assert first_runner['success_rate'] == 0.0
|
||||
@@ -0,0 +1,633 @@
|
||||
"""Tests for AgentRunSessionRegistry."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from langbot.pkg.agent.runner.session_registry import (
|
||||
AgentRunSessionRegistry,
|
||||
AgentRunSession,
|
||||
MAX_STEERING_QUEUE_ITEMS,
|
||||
get_session_registry,
|
||||
)
|
||||
|
||||
# Import shared test fixtures from conftest.py
|
||||
from .conftest import make_resources, make_session
|
||||
|
||||
|
||||
class TestSessionRegistryBasic:
|
||||
"""Tests for basic registry operations."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_register_and_get(self):
    """A registered session is returned by get() with its identity and authorization data."""
    registry = AgentRunSessionRegistry()
    run_id = 'run_abc'
    model_entry = {'model_id': 'model_001', 'model_type': 'chat', 'provider': 'openai'}
    tool_entry = {'tool_name': 'web_search', 'tool_type': 'builtin'}
    resources = make_resources(models=[model_entry], tools=[tool_entry])

    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=resources,
    )

    session = await registry.get(run_id)

    assert session is not None
    assert session['run_id'] == run_id
    assert session['runner_id'] == 'plugin:test/my-runner/default'
    assert session['query_id'] == 1
    assert session['plugin_identity'] == 'test/my-runner'

    authorized_models = session['authorization']['resources']['models']
    assert len(authorized_models) == 1
    assert authorized_models[0]['model_id'] == 'model_001'

    # These keys must not appear at the top level of the returned session.
    for absent_key in ('resources', 'permissions', '_authorized_ids'):
        assert absent_key not in session
|
||||
|
||||
@pytest.mark.asyncio
async def test_register_requires_plugin_identity(self):
    """Registering with an empty plugin_identity raises ValueError."""
    registry = AgentRunSessionRegistry()
    registration = {
        'run_id': 'run_missing_identity',
        'runner_id': 'plugin:test/my-runner/default',
        'query_id': 1,
        'plugin_identity': '',
    }

    with pytest.raises(ValueError, match='plugin_identity is required'):
        await registry.register(resources=make_resources(), **registration)
|
||||
|
||||
@pytest.mark.asyncio
async def test_register_freezes_authorization_snapshot(self):
    """Mutating the resources after register() does not change what the session authorizes."""
    registry = AgentRunSessionRegistry()
    initial_storage = {'plugin_storage': True, 'workspace_storage': False}
    resources = make_resources(
        models=[{'model_id': 'model_001'}],
        storage=initial_storage,
    )

    await registry.register(
        run_id='run_snapshot',
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=resources,
        available_apis={'history_page': True},
        conversation_id='conv_001',
    )

    # Mutate the caller-side resources after registration.
    resources['models'].append({'model_id': 'model_late'})
    resources['storage']['workspace_storage'] = True

    session = await registry.get('run_snapshot')
    assert session is not None

    authorization = session['authorization']
    assert authorization['conversation_id'] == 'conv_001'
    assert authorization['available_apis'] == {'history_page': True}

    # Only what was present at registration time is allowed.
    assert registry.is_resource_allowed(session, 'model', 'model_001') is True
    assert registry.is_resource_allowed(session, 'model', 'model_late') is False
    assert registry.is_resource_allowed(session, 'storage', 'workspace') is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_nonexistent_session(self):
    """get() returns None for a run_id that was never registered."""
    empty_registry = AgentRunSessionRegistry()

    missing = await empty_registry.get('nonexistent_run')

    assert missing is None
|
||||
|
||||
@pytest.mark.asyncio
async def test_unregister(self):
    """After unregister(), get() no longer returns the session."""
    registry = AgentRunSessionRegistry()
    run_id = 'run_xyz'

    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
    )

    # The session is visible while registered.
    before_removal = await registry.get(run_id)
    assert before_removal is not None

    await registry.unregister(run_id)

    # The session is gone once unregistered.
    after_removal = await registry.get(run_id)
    assert after_removal is None
|
||||
|
||||
@pytest.mark.asyncio
async def test_unregister_nonexistent(self):
    """unregister() on an unknown run_id completes without raising."""
    empty_registry = AgentRunSessionRegistry()

    # Passing means no exception escaped the call.
    await empty_registry.unregister('nonexistent_run')
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_activity(self):
    """update_activity() moves last_activity_at forward from a backdated value."""
    registry = AgentRunSessionRegistry()
    run_id = 'run_activity'

    # Build a session whose timestamps lie 100 seconds in the past.
    backdated = int(time.time()) - 100
    old_session: AgentRunSession = make_session(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
    )
    old_session['status'] = {
        'started_at': backdated,
        'last_activity_at': backdated,
    }

    # Insert it straight into the registry's storage, bypassing register().
    async with registry._lock:
        registry._sessions[run_id] = old_session

    before = await registry.get(run_id)
    activity_before = before['status']['last_activity_at']

    await registry.update_activity(run_id)

    after = await registry.get(run_id)
    activity_after = after['status']['last_activity_at']
    assert activity_after > activity_before
    # The refreshed timestamp is at least the 100 backdated seconds newer.
    assert activity_after - activity_before >= 100
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_activity_nonexistent(self):
    """update_activity() on an unknown run_id completes without raising."""
    empty_registry = AgentRunSessionRegistry()
    # The test passes as long as this await does not raise.
    await empty_registry.update_activity('nonexistent_run')
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_active_runs(self):
    """list_active_runs() reports every registered session."""
    registry = AgentRunSessionRegistry()
    await registry.register('run_1', 'plugin:a/b/default', 1, 'a/b', make_resources())
    await registry.register('run_2', 'plugin:c/d/default', 2, 'c/d', make_resources())

    listed = await registry.list_active_runs()

    assert len(listed) == 2
    seen_ids = [entry['run_id'] for entry in listed]
    for expected in ('run_1', 'run_2'):
        assert expected in seen_ids
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_stale_sessions(self):
    """cleanup_stale_sessions() drops the 2h-old session and keeps the fresh one."""
    registry = AgentRunSessionRegistry()
    current = int(time.time())

    # (run_id, query_id, age in seconds) for each seeded session.
    seeded = {}
    for name, qid, age in (('old_run', 1, 7200), ('new_run', 2, 0)):
        entry = make_session(
            run_id=name,
            runner_id='plugin:test/runner/default',
            query_id=qid,
            plugin_identity='test/runner',
        )
        stamp = current - age
        entry['status'] = {'started_at': stamp, 'last_activity_at': stamp}
        seeded[name] = entry

    # Insert directly into the registry's storage so the chosen timestamps are kept.
    async with registry._lock:
        for name, entry in seeded.items():
            registry._sessions[name] = entry

    # One-hour threshold: only the two-hour-old session qualifies.
    removed_count = await registry.cleanup_stale_sessions(max_age_seconds=3600)
    assert removed_count == 1

    assert await registry.get('old_run') is None
    assert await registry.get('new_run') is not None
|
||||
|
||||
@pytest.mark.asyncio
async def test_pull_steering_all_preserves_queue_order(self):
    """mode='all' drains the whole queue in the order items were enqueued."""
    run_id = 'run_steering'
    registry = AgentRunSessionRegistry()
    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
        conversation_id='conv_1',
        available_apis={'steering_pull': True},
    )

    queued = (('event_1', 'first'), ('event_2', 'second'), ('event_3', 'third'))
    for event_id, text in queued:
        await registry.enqueue_steering(run_id, {'event': {'event_id': event_id}, 'input': {'text': text}})

    drained = await registry.pull_steering(run_id, mode='all')
    assert [entry['event']['event_id'] for entry in drained] == ['event_1', 'event_2', 'event_3']

    # A second pull finds nothing left.
    assert await registry.pull_steering(run_id, mode='all') == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_pull_steering_one_at_a_time_leaves_remaining_items(self):
    """mode='one-at-a-time' hands out a single queued item per pull."""
    run_id = 'run_steering_one'
    registry = AgentRunSessionRegistry()
    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
        conversation_id='conv_1',
        available_apis={'steering_pull': True},
    )

    for event_id in ('event_1', 'event_2'):
        await registry.enqueue_steering(run_id, {'event': {'event_id': event_id}})

    # Both pulls happen before any assertion, as each one consumes from the queue.
    first_pull = await registry.pull_steering(run_id, mode='one-at-a-time')
    second_pull = await registry.pull_steering(run_id, mode='one-at-a-time')

    assert [entry['event']['event_id'] for entry in first_pull] == ['event_1']
    assert [entry['event']['event_id'] for entry in second_pull] == ['event_2']
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_steering_rejects_when_queue_is_full(self):
    """enqueue_steering() turns falsy once MAX_STEERING_QUEUE_ITEMS items are queued."""
    run_id = 'run_steering_full'
    registry = AgentRunSessionRegistry()
    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
        conversation_id='conv_1',
        available_apis={'steering_pull': True},
    )

    # Fill the queue to capacity; every one of these must be accepted.
    for slot in range(MAX_STEERING_QUEUE_ITEMS):
        accepted = await registry.enqueue_steering(run_id, {'event': {'event_id': f'event_{slot}'}})
        assert accepted

    # One more item is refused.
    overflow_accepted = await registry.enqueue_steering(run_id, {'event': {'event_id': 'overflow'}})
    assert not overflow_accepted

    # The refused item never entered the queue.
    drained = await registry.pull_steering(run_id, mode='all')
    assert len(drained) == MAX_STEERING_QUEUE_ITEMS
    assert 'overflow' not in [entry['event']['event_id'] for entry in drained]
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_steering_target_requires_same_scope(self):
    """find_steering_target() matches only when bot and thread ids agree with the run."""
    runner = 'plugin:test/my-runner/default'
    registry = AgentRunSessionRegistry()
    await registry.register(
        run_id='run_steering_scoped',
        runner_id=runner,
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
        conversation_id='conv_1',
        bot_id='bot_1',
        workspace_id='workspace_1',
        thread_id='thread_1',
        available_apis={'steering_pull': True},
    )

    async def lookup(bot_id, thread_id):
        # Conversation, runner and workspace stay fixed; only bot/thread vary per probe.
        return await registry.find_steering_target(
            conversation_id='conv_1',
            runner_id=runner,
            bot_id=bot_id,
            workspace_id='workspace_1',
            thread_id=thread_id,
        )

    assert await lookup('bot_1', 'thread_1') == 'run_steering_scoped'
    assert await lookup('bot_2', 'thread_1') is None
    assert await lookup('bot_1', 'thread_2') is None
|
||||
|
||||
@pytest.mark.asyncio
async def test_unregister_returns_pending_steering_queue(self):
    """unregister() hands back the removed session, including un-pulled steering items."""
    run_id = 'run_steering_unregister'
    registry = AgentRunSessionRegistry()
    await registry.register(
        run_id=run_id,
        runner_id='plugin:test/my-runner/default',
        query_id=1,
        plugin_identity='test/my-runner',
        resources=make_resources(),
        conversation_id='conv_1',
        available_apis={'steering_pull': True},
    )
    await registry.enqueue_steering(run_id, {'event': {'event_id': 'event_pending'}})

    removed = await registry.unregister(run_id)

    assert removed is not None
    head_of_queue = removed['steering_queue'][0]
    assert head_of_queue['event']['event_id'] == 'event_pending'
    assert await registry.get(run_id) is None
|
||||
|
||||
|
||||
class TestIsResourceAllowed:
    """Grant checks performed by AgentRunSessionRegistry.is_resource_allowed."""

    def test_model_allowed(self):
        """Every model listed in the session resources is granted."""
        granted = [
            {'model_id': 'model_001', 'model_type': 'chat', 'provider': 'openai'},
            {'model_id': 'model_002', 'model_type': 'embedding', 'provider': 'anthropic'},
        ]
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(models=granted))

        for model_id in ('model_001', 'model_002'):
            assert checker.is_resource_allowed(session, 'model', model_id) is True

    def test_model_operation_denied(self):
        """A model grant limited to 'invoke' rejects the 'stream' operation."""
        checker = AgentRunSessionRegistry()
        session = make_session(
            resources=make_resources(models=[{'model_id': 'model_001', 'operations': ['invoke']}])
        )

        assert checker.is_resource_allowed(session, 'model', 'model_001', 'invoke') is True
        assert checker.is_resource_allowed(session, 'model', 'model_001', 'stream') is False

    def test_model_not_allowed(self):
        """A model id absent from the resources is rejected."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(models=[{'model_id': 'model_001'}]))

        assert checker.is_resource_allowed(session, 'model', 'model_999') is False

    def test_model_empty_resources(self):
        """With an empty models list no model is granted."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(models=[]))

        assert checker.is_resource_allowed(session, 'model', 'model_001') is False

    def test_tool_allowed(self):
        """Every tool listed in the session resources is granted."""
        granted = [
            {'tool_name': 'web_search', 'tool_type': 'builtin'},
            {'tool_name': 'code_exec', 'tool_type': 'plugin'},
        ]
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(tools=granted))

        for tool_name in ('web_search', 'code_exec'):
            assert checker.is_resource_allowed(session, 'tool', tool_name) is True

    def test_tool_operation_denied(self):
        """A tool grant limited to 'detail' rejects the 'call' operation."""
        checker = AgentRunSessionRegistry()
        session = make_session(
            resources=make_resources(tools=[{'tool_name': 'web_search', 'operations': ['detail']}])
        )

        assert checker.is_resource_allowed(session, 'tool', 'web_search', 'detail') is True
        assert checker.is_resource_allowed(session, 'tool', 'web_search', 'call') is False

    def test_tool_not_allowed(self):
        """A tool name absent from the resources is rejected."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(tools=[{'tool_name': 'web_search'}]))

        assert checker.is_resource_allowed(session, 'tool', 'image_gen') is False

    def test_tool_empty_resources(self):
        """With an empty tools list no tool is granted."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(tools=[]))

        assert checker.is_resource_allowed(session, 'tool', 'web_search') is False

    def test_knowledge_base_allowed(self):
        """Every knowledge base listed in the session resources is granted."""
        granted = [
            {'kb_id': 'kb_001', 'kb_name': 'docs', 'kb_type': 'vector'},
            {'kb_id': 'kb_002', 'kb_name': 'faq', 'kb_type': 'keyword'},
        ]
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(knowledge_bases=granted))

        for kb_id in ('kb_001', 'kb_002'):
            assert checker.is_resource_allowed(session, 'knowledge_base', kb_id) is True

    def test_knowledge_base_not_allowed(self):
        """A knowledge base id absent from the resources is rejected."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(knowledge_bases=[{'kb_id': 'kb_001'}]))

        assert checker.is_resource_allowed(session, 'knowledge_base', 'kb_999') is False

    def test_knowledge_base_empty_resources(self):
        """With an empty knowledge_bases list no knowledge base is granted."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(knowledge_bases=[]))

        assert checker.is_resource_allowed(session, 'knowledge_base', 'kb_001') is False

    def test_skill_allowed(self):
        """Listed skills are granted; an unlisted skill name is rejected."""
        granted = [
            {'skill_name': 'demo', 'display_name': 'Demo'},
            {'skill_name': 'writer', 'display_name': 'Writer'},
        ]
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources(skills=granted))

        for skill_name in ('demo', 'writer'):
            assert checker.is_resource_allowed(session, 'skill', skill_name) is True
        assert checker.is_resource_allowed(session, 'skill', 'hidden') is False

    def test_storage_plugin_allowed(self):
        """plugin_storage=True grants 'plugin' storage while 'workspace' stays denied."""
        checker = AgentRunSessionRegistry()
        session = make_session(
            resources=make_resources(storage={'plugin_storage': True, 'workspace_storage': False})
        )

        assert checker.is_resource_allowed(session, 'storage', 'plugin') is True
        assert checker.is_resource_allowed(session, 'storage', 'workspace') is False

    def test_storage_workspace_allowed(self):
        """workspace_storage=True grants 'workspace' storage while 'plugin' stays denied."""
        checker = AgentRunSessionRegistry()
        session = make_session(
            resources=make_resources(storage={'plugin_storage': False, 'workspace_storage': True})
        )

        assert checker.is_resource_allowed(session, 'storage', 'plugin') is False
        assert checker.is_resource_allowed(session, 'storage', 'workspace') is True

    def test_storage_both_denied(self):
        """With both storage flags False neither storage kind is granted."""
        checker = AgentRunSessionRegistry()
        session = make_session(
            resources=make_resources(storage={'plugin_storage': False, 'workspace_storage': False})
        )

        for storage_kind in ('plugin', 'workspace'):
            assert checker.is_resource_allowed(session, 'storage', storage_kind) is False

    def test_unknown_resource_type(self):
        """An unrecognised resource type is rejected."""
        checker = AgentRunSessionRegistry()
        session = make_session(resources=make_resources())

        assert checker.is_resource_allowed(session, 'unknown_type', 'something') is False

    def test_missing_resources_field(self):
        """Resource categories absent from the dict are treated as denied, not as errors."""
        checker = AgentRunSessionRegistry()
        # Only 'models' is present; the 'tools' and 'knowledge_bases' keys are missing.
        session = make_session(resources={'models': []})

        assert checker.is_resource_allowed(session, 'tool', 'web_search') is False
        assert checker.is_resource_allowed(session, 'knowledge_base', 'kb_001') is False
|
||||
|
||||
|
||||
class TestGlobalRegistry:
    """Checks around the module-level registry accessor."""

    def test_get_session_registry_returns_instance(self):
        """The accessor is importable and callable; the registry class can be built directly."""
        # Deliberately does not call the accessor: this test only verifies that
        # it exists and is callable, leaving the module-level global untouched.
        from langbot.pkg.agent.runner.session_registry import get_session_registry

        assert callable(get_session_registry)

        # Constructing the class directly confirms it is usable on its own.
        standalone = AgentRunSessionRegistry()
        assert isinstance(standalone, AgentRunSessionRegistry)

    def test_global_registry_singleton_behavior(self):
        """Repeated get_session_registry() calls return one and the same object."""
        from langbot.pkg.agent.runner.session_registry import _global_registry

        # Snapshot of the module variable at import time: None or an existing instance.
        snapshot = _global_registry

        first = get_session_registry()
        if snapshot is None:
            # Nothing existed at snapshot time, so only the type of the result can be checked.
            assert isinstance(first, AgentRunSessionRegistry)

        second = get_session_registry()
        assert first is second

        if snapshot is not None:
            # An instance already existed; the accessor must hand back that very object.
            assert first is snapshot
|
||||
|
||||
|
||||
class TestThreadSafety:
    """Registry behaviour when several coroutines are scheduled together via asyncio.gather."""

    @pytest.mark.asyncio
    async def test_concurrent_register(self):
        """Ten register() calls gathered at once all end up in the registry."""
        registry = AgentRunSessionRegistry()

        pending = [
            registry.register(
                f'run_{i}',
                'plugin:test/runner/default',
                i,
                'test/runner',
                make_resources(),
            )
            for i in range(10)
        ]
        await asyncio.gather(*pending)

        assert len(await registry.list_active_runs()) == 10

    @pytest.mark.asyncio
    async def test_concurrent_register_and_unregister(self):
        """unregister() racing with get() still leaves the session removed afterwards."""
        registry = AgentRunSessionRegistry()
        await registry.register('run_1', 'plugin:test/runner/default', 1, 'test/runner', make_resources())

        # Schedule removal and lookup together; their individual results are not inspected.
        await asyncio.gather(
            registry.unregister('run_1'),
            registry.get('run_1'),
        )

        # Once both have finished the session must be gone.
        assert await registry.get('run_1') is None
|
||||
@@ -0,0 +1,544 @@
|
||||
"""Tests for State API handler authorization in RuntimeConnectionHandler.
|
||||
|
||||
Tests focus on:
|
||||
- STATE_GET authorization
|
||||
- STATE_SET authorization
|
||||
- STATE_DELETE authorization
|
||||
- STATE_LIST authorization
|
||||
|
||||
These tests instantiate real RuntimeConnectionHandler action handlers and verify:
|
||||
- Authorization errors for missing/mismatched caller_plugin_identity
|
||||
- Authorization errors for disabled state or scope
|
||||
- Full flow: set -> get -> list -> delete with real SQLite
|
||||
|
||||
Authorization rules:
|
||||
- caller_plugin_identity is REQUIRED when session has plugin_identity
|
||||
- caller_plugin_identity must match session's plugin_identity
|
||||
- enable_state must be True
|
||||
- scope must be in state_scopes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore, reset_persistent_state_store
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
# Import shared test fixtures
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
    """Empty placeholder passed to RuntimeConnectionHandler as its connection in tests."""
|
||||
|
||||
|
||||
class FakeApplication:
    """Minimal application stand-in exposing a mocked logger and persistence manager."""

    def __init__(self, db_engine=None):
        # Logger whose debug/warning/error attributes are explicit MagicMocks.
        log = MagicMock()
        for level in ('debug', 'warning', 'error'):
            setattr(log, level, MagicMock())

        # Persistence manager whose get_db_engine() returns the supplied engine.
        persistence = MagicMock()
        persistence.get_db_engine = MagicMock(return_value=db_engine)

        self.logger = log
        self.persistence_mgr = persistence
|
||||
|
||||
|
||||
@pytest.fixture
def session_registry():
    """Provide a brand-new AgentRunSessionRegistry to each test that requests it."""
    fresh_registry = AgentRunSessionRegistry()
    return fresh_registry
|
||||
|
||||
|
||||
@pytest.fixture
async def db_engine():
    """Yield an async SQLAlchemy engine on in-memory SQLite, disposing it afterwards."""
    memory_engine = create_async_engine('sqlite+aiosqlite:///:memory:')
    yield memory_engine
    # Teardown: dispose of the engine once the test is done with it.
    await memory_engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
async def persistent_store(db_engine):
    """Yield a PersistentStateStore on the test engine with its table created.

    reset_persistent_state_store() is called both before the store is built
    and again during teardown.
    """
    from langbot.pkg.entity.persistence.agent_runner_state import AgentRunnerState

    reset_persistent_state_store()
    state_store = PersistentStateStore(db_engine)

    # Create the backing table (skipped if it already exists).
    async with db_engine.begin() as connection:
        await connection.run_sync(AgentRunnerState.__table__.create, checkfirst=True)

    yield state_store
    reset_persistent_state_store()
|
||||
|
||||
|
||||
class TestStateAPIHandlerAuthorization:
    """Authorization failures of the STATE_GET action, exercised through the real handler."""

    @staticmethod
    async def _register_run(registry, run_id, state_policy, state_context):
        """Register a run owned by plugin 'test/runner' with the given state policy/context."""
        await registry.register(
            run_id=run_id,
            runner_id='plugin:test/runner/default',
            query_id=1,
            plugin_identity='test/runner',
            resources=make_resources(),
            available_apis={'state': True},
            state_policy=state_policy,
            state_context=state_context,
        )

    @staticmethod
    async def _invoke_state_get(registry, db_engine, payload):
        """Build a handler wired to *registry* and return its STATE_GET result for *payload*."""
        app = FakeApplication(db_engine)
        app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)

        async def fake_disconnect():
            return True

        # The handler is both constructed and invoked while the registry lookup is patched.
        with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=registry):
            handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, app)
            # handler.actions is keyed by the action enum's string value.
            state_get = handler.actions[PluginToRuntimeAction.STATE_GET.value]
            return await state_get(payload)

    @pytest.mark.asyncio
    async def test_state_get_missing_run_id_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET without a run_id is rejected with a 'run_id is required' message."""
        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {'scope': 'conversation', 'key': 'test_key'},
        )

        assert outcome.code != 0
        assert 'run_id is required' in outcome.message

    @pytest.mark.asyncio
    async def test_state_get_run_not_found_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET for a run_id missing from the registry is rejected as not found."""
        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': 'nonexistent_run',
                'scope': 'conversation',
                'key': 'test_key',
            },
        )

        assert outcome.code != 0
        assert 'not found' in outcome.message.lower()

    @pytest.mark.asyncio
    async def test_state_get_missing_caller_plugin_identity_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET without caller_plugin_identity is rejected when the session has an owner."""
        run_id = 'run_test_missing_identity'
        await self._register_run(
            session_registry,
            run_id,
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
        )

        # The payload carries no caller_plugin_identity at all.
        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': run_id,
                'scope': 'conversation',
                'key': 'test_key',
            },
        )

        assert outcome.code != 0
        assert 'caller_plugin_identity is required' in outcome.message

        await session_registry.unregister(run_id)

    @pytest.mark.asyncio
    async def test_state_get_caller_identity_mismatch_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET from a caller other than the session's plugin is rejected as a mismatch."""
        run_id = 'run_test_mismatch'
        await self._register_run(
            session_registry,
            run_id,
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
        )

        # The caller identity differs from the registered 'test/runner'.
        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': run_id,
                'scope': 'conversation',
                'key': 'test_key',
                'caller_plugin_identity': 'other/plugin',
            },
        )

        assert outcome.code != 0
        assert 'mismatch' in outcome.message.lower()

        await session_registry.unregister(run_id)

    @pytest.mark.asyncio
    async def test_state_get_enable_state_false_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET is rejected as disabled when the policy has enable_state=False."""
        run_id = 'run_test_disabled'
        await self._register_run(
            session_registry,
            run_id,
            state_policy={'enable_state': False, 'state_scopes': []},
            state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
        )

        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': run_id,
                'scope': 'conversation',
                'key': 'test_key',
                'caller_plugin_identity': 'test/runner',
            },
        )

        assert outcome.code != 0
        assert 'disabled' in outcome.message.lower()

        await session_registry.unregister(run_id)

    @pytest.mark.asyncio
    async def test_state_get_scope_not_enabled_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET for a scope missing from state_scopes is rejected."""
        run_id = 'run_test_scope_disabled'
        await self._register_run(
            session_registry,
            run_id,
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            state_context={'scope_keys': {'conversation': 'conv_key', 'actor': 'actor_key'}, 'binding_identity': 'binding_1'},
        )

        # 'actor' has a scope key but is not listed in state_scopes.
        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': run_id,
                'scope': 'actor',
                'key': 'test_key',
                'caller_plugin_identity': 'test/runner',
            },
        )

        assert outcome.code != 0
        lowered = outcome.message.lower()
        assert 'not enabled' in lowered or 'scope' in lowered

        await session_registry.unregister(run_id)

    @pytest.mark.asyncio
    async def test_state_get_missing_scope_key_returns_error(self, session_registry, db_engine, persistent_store):
        """STATE_GET is rejected as not available when state_context has no key for the scope."""
        run_id = 'run_test_no_scope_key'
        await self._register_run(
            session_registry,
            run_id,
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            # scope_keys is empty, so the enabled 'conversation' scope has no key.
            state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
        )

        outcome = await self._invoke_state_get(
            session_registry,
            db_engine,
            {
                'run_id': run_id,
                'scope': 'conversation',
                'key': 'test_key',
                'caller_plugin_identity': 'test/runner',
            },
        )

        assert outcome.code != 0
        assert 'not available' in outcome.message.lower()

        await session_registry.unregister(run_id)
|
||||
|
||||
|
||||
class TestStateAPIFullFlowWithRealDB:
    """End-to-end State API round trip through the runtime handlers with real SQLite."""

    @pytest.mark.asyncio
    async def test_state_set_get_list_delete_flow(self, session_registry, db_engine, persistent_store):
        """Run set -> get -> list -> delete -> get against one conversation-scoped key."""
        run_id = 'run_full_flow'
        state_key = 'external.test_key'

        app = FakeApplication(db_engine)
        app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)

        # Register a session that enables both the conversation and runner scopes.
        scope_keys = {
            'conversation': 'conv:test_runner:binding_1:conv_123',
            'runner': 'runner:test_runner:binding_1',
        }
        await session_registry.register(
            run_id=run_id,
            runner_id='plugin:test/runner/default',
            query_id=1,
            plugin_identity='test/runner',
            resources=make_resources(),
            available_apis={'state': True},
            state_policy={'enable_state': True, 'state_scopes': ['conversation', 'runner']},
            state_context={
                'scope_keys': scope_keys,
                'binding_identity': 'binding_1',
                'conversation_id': 'conv_123',
            },
        )

        async def noop_disconnect():
            return True

        def request(**fields):
            """Build a conversation-scope payload for this run with the given extra fields."""
            return {
                'run_id': run_id,
                'scope': 'conversation',
                **fields,
                'caller_plugin_identity': 'test/runner',
            }

        with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
            handler = RuntimeConnectionHandler(FakeConnection(), noop_disconnect, app)

            # Sanity-check the registered authorization snapshot before using it.
            session = await session_registry.get(run_id)
            assert session is not None
            state_ctx = session['authorization']['state_context']
            assert state_ctx is not None, f"state_context is None. Session keys: {list(session.keys())}"
            assert 'scope_keys' in state_ctx, f"scope_keys not in state_context: {state_ctx}"
            assert 'conversation' in state_ctx['scope_keys'], f"conversation not in scope_keys: {state_ctx['scope_keys']}"

            # The actions dict is keyed by the action's string value.
            actions = handler.actions
            do_set = actions[PluginToRuntimeAction.STATE_SET.value]
            do_get = actions[PluginToRuntimeAction.STATE_GET.value]
            do_list = actions[PluginToRuntimeAction.STATE_LIST.value]
            do_delete = actions[PluginToRuntimeAction.STATE_DELETE.value]

            # 1. STATE_SET
            set_response = await do_set(request(key=state_key, value={'data': 'test_value'}))
            assert set_response.code == 0
            assert set_response.data.get('success') is True

            # 2. STATE_GET
            get_response = await do_get(request(key=state_key))
            assert get_response.code == 0
            assert get_response.data.get('value') == {'data': 'test_value'}

            # 3. STATE_LIST
            list_response = await do_list(request(prefix='external.'))
            assert list_response.code == 0
            listed_keys = list_response.data.get('keys', [])
            assert state_key in listed_keys

            # 4. STATE_DELETE
            delete_response = await do_delete(request(key=state_key))
            assert delete_response.code == 0

            # 5. The key must read back as absent after deletion.
            reread = await do_get(request(key=state_key))
            assert reread.code == 0
            assert reread.data.get('value') is None

        await session_registry.unregister(run_id)
|
||||
|
||||
|
||||
class TestStateHandlerReadsFromAuthorizationSnapshot:
    """Handlers must take state_policy/state_context from the authorization snapshot."""

    @pytest.mark.asyncio
    async def test_state_handler_reads_state_policy_from_authorization(self, session_registry, db_engine, persistent_store):
        """A disabled policy in session['authorization'] makes STATE_GET fail."""
        run_id = 'run_policy_top_level'

        app = FakeApplication(db_engine)
        app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)

        # State is explicitly disabled in the authorization snapshot.
        await session_registry.register(
            run_id=run_id,
            runner_id='plugin:test/runner/default',
            query_id=1,
            plugin_identity='test/runner',
            resources=make_resources(),
            available_apis={'state': True},
            state_policy={'enable_state': False, 'state_scopes': []},
            state_context={'scope_keys': {}, 'binding_identity': 'binding_1'},
        )

        # The policy must not have been folded into resources.
        session = await session_registry.get(run_id)
        assert session is not None
        snapshot_resources = session['authorization']['resources']
        assert 'state_policy' not in snapshot_resources, "resources should NOT contain state_policy"

        async def noop_disconnect():
            return True

        request = {
            'run_id': run_id,
            'scope': 'conversation',
            'key': 'test_key',
            'caller_plugin_identity': 'test/runner',
        }

        with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
            connection_handler = RuntimeConnectionHandler(FakeConnection(), noop_disconnect, app)
            get_state = connection_handler.actions[PluginToRuntimeAction.STATE_GET.value]

            # enable_state=False in authorization.state_policy must reject the call.
            response = await get_state(request)

            assert response.code != 0
            assert 'disabled' in response.message.lower()

        await session_registry.unregister(run_id)

    @pytest.mark.asyncio
    async def test_state_handler_reads_state_context_from_authorization(self, session_registry, db_engine, persistent_store):
        """A scope key present in session['authorization'] lets STATE_SET succeed."""
        run_id = 'run_context_top_level'

        app = FakeApplication(db_engine)
        app.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)

        # The conversation scope key lives only in the authorization snapshot.
        await session_registry.register(
            run_id=run_id,
            runner_id='plugin:test/runner/default',
            query_id=1,
            plugin_identity='test/runner',
            resources=make_resources(),
            available_apis={'state': True},
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            state_context={'scope_keys': {'conversation': 'conv_key_xyz'}, 'binding_identity': 'binding_xyz'},
        )

        # The context must not have been folded into resources.
        session = await session_registry.get(run_id)
        assert session is not None
        snapshot_resources = session['authorization']['resources']
        assert 'state_context' not in snapshot_resources, "resources should NOT contain state_context"

        async def noop_disconnect():
            return True

        request = {
            'run_id': run_id,
            'scope': 'conversation',
            'key': 'test_key',
            'value': 'test_value',
            'caller_plugin_identity': 'test/runner',
        }

        with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
            connection_handler = RuntimeConnectionHandler(FakeConnection(), noop_disconnect, app)
            set_state = connection_handler.actions[PluginToRuntimeAction.STATE_SET.value]

            # The handler should resolve authorization.state_context.scope_keys.conversation.
            response = await set_state(request)

            # Success means the scope key was found in state_context.
            assert response.code == 0

        await session_registry.unregister(run_id)
|
||||
|
||||
|
||||
class TestResourcesDoesNotContainStateMetadata:
    """The registered resources mapping must stay free of state metadata."""

    @pytest.mark.asyncio
    async def test_resources_clean_after_register(self, session_registry):
        """register() nests resources and state metadata side by side under authorization."""
        run_id = 'run_resources_clean'
        state_fields = ('state_policy', 'state_context')
        resources = make_resources()

        await session_registry.register(
            run_id=run_id,
            runner_id='plugin:test/runner/default',
            query_id=1,
            plugin_identity='test/runner',
            resources=resources,
            state_policy={'enable_state': True, 'state_scopes': ['conversation']},
            state_context={'scope_keys': {'conversation': 'conv_key'}, 'binding_identity': 'binding_1'},
        )

        session = await session_registry.get(run_id)
        assert session is not None

        # resources lives under authorization only, and carries no state metadata.
        assert 'resources' not in session
        authorization = session['authorization']
        session_resources = authorization['resources']
        for field in state_fields:
            assert field not in session_resources, \
                f"authorization['resources'] should NOT contain {field}"

        # The state metadata itself is stored directly on the authorization snapshot.
        for field in state_fields:
            assert field in authorization

        await session_registry.unregister(run_id)
|
||||
@@ -0,0 +1,383 @@
|
||||
"""Tests for persistent AgentRunner state store."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import BindingScope, StatePolicy
|
||||
from langbot.pkg.agent.runner.persistent_state_store import PersistentStateStore
|
||||
from langbot.pkg.agent.runner.state_scope import (
|
||||
STATE_KEY_ALIASES,
|
||||
VALID_STATE_SCOPES,
|
||||
build_state_context,
|
||||
build_state_scope_key,
|
||||
get_binding_identity,
|
||||
normalize_state_key,
|
||||
)
|
||||
|
||||
|
||||
def make_descriptor(runner_id: str = 'plugin:test/my-runner/default') -> AgentRunnerDescriptor:
    """Build a plugin-sourced, streaming-capable runner descriptor for tests.

    Only the descriptor id follows *runner_id*; the author, plugin and runner
    name fields are fixed test values.
    """
    descriptor_fields = {
        'id': runner_id,
        'source': 'plugin',
        'label': {'en_US': 'Test Runner'},
        'plugin_author': 'test',
        'plugin_name': 'my-runner',
        'runner_name': 'default',
        'capabilities': {'streaming': True},
    }
    return AgentRunnerDescriptor(**descriptor_fields)
|
||||
|
||||
|
||||
class FakeActorContext:
    """Stand-in for an event's actor, exposing type, id and display name."""

    def __init__(self, actor_type: str = 'user', actor_id: str = 'user_123', actor_name: str = 'Test User'):
        """Store the three actor fields verbatim as attributes."""
        self.actor_type, self.actor_id, self.actor_name = actor_type, actor_id, actor_name
|
||||
|
||||
|
||||
class FakeSubjectContext:
    """Stand-in for an event's subject, exposing type, id and a data mapping."""

    def __init__(self, subject_type: str = 'message', subject_id: str = 'msg_001', data: dict | None = None):
        """Store the subject fields as attributes.

        Args:
            subject_type: Kind of subject the event refers to.
            subject_id: Identifier of the subject.
            data: Payload mapping; a fresh empty dict is used when omitted.
        """
        self.subject_type = subject_type
        self.subject_id = subject_id
        # Test for None explicitly so a caller-supplied empty dict is kept
        # (same object) instead of being replaced by a new one.
        self.data = {} if data is None else data
|
||||
|
||||
|
||||
class FakeEventEnvelope:
    """Stand-in event envelope used by the event-first state tests."""

    def __init__(
        self,
        event_id: str = 'evt_001',
        event_type: str = 'message.received',
        conversation_id: str | None = 'conv_001',
        actor: FakeActorContext | None = None,
        subject: FakeSubjectContext | None = None,
        bot_id: str = 'bot_001',
        workspace_id: str = 'ws_001',
        thread_id: str | None = None,
    ):
        """Populate the envelope; a falsy *actor* is replaced by a default FakeActorContext."""
        # Fixed values that callers cannot override.
        self.event_time = 1700000000
        self.source = 'platform'
        self.raw_ref = None

        # Event identity.
        self.event_id, self.event_type = event_id, event_type

        # Where the event happened.
        self.bot_id, self.workspace_id = bot_id, workspace_id
        self.conversation_id, self.thread_id = conversation_id, thread_id

        # Who and what the event concerns.
        self.actor = actor if actor else FakeActorContext()
        self.subject = subject
|
||||
|
||||
|
||||
class FakeBinding:
    """Stand-in binding exposing ``binding_id``, ``scope`` and ``state_policy``."""

    def __init__(
        self,
        binding_id: str = 'binding_001',
        state_policy: StatePolicy | None = None,
        scope_type: str = 'agent',
        scope_id: str = 'agent_001',
    ):
        """Build the binding; a falsy *state_policy* is replaced by a default StatePolicy."""
        self.binding_id = binding_id
        self.scope = BindingScope(scope_type=scope_type, scope_id=scope_id)
        self.state_policy = state_policy if state_policy else StatePolicy()
|
||||
|
||||
|
||||
class TestStateScopeHelpers:
    """Checks for the helpers exported by the shared state_scope module."""

    def test_valid_state_scopes(self):
        """The supported scopes are exactly these four, in this order."""
        expected_scopes = ('conversation', 'actor', 'subject', 'runner')
        assert VALID_STATE_SCOPES == expected_scopes

    def test_state_key_aliases(self):
        """``conversation_id`` is aliased; already-namespaced keys pass through unchanged."""
        assert STATE_KEY_ALIASES == {'conversation_id': 'external.conversation_id'}
        assert normalize_state_key('conversation_id') == 'external.conversation_id'
        assert normalize_state_key('external.session_id') == 'external.session_id'

    def test_binding_identity_uses_binding_id_first(self):
        """A non-empty binding_id is the binding identity."""
        identity = get_binding_identity(FakeBinding(binding_id='binding_a'))
        assert identity == 'binding_a'

    def test_binding_identity_falls_back_to_scope(self):
        """With an empty binding_id the identity is ``<scope_type>:<scope_id>``."""
        scoped_binding = FakeBinding(binding_id='', scope_type='workspace', scope_id='ws_001')
        assert get_binding_identity(scoped_binding) == 'workspace:ws_001'

    def test_scope_key_building(self):
        """Every scope yields a distinct key prefixed with ``<scope>:v2:``."""
        descriptor = make_descriptor()
        binding = FakeBinding(binding_id='binding_a')
        event = FakeEventEnvelope(
            conversation_id='conv_001',
            actor=FakeActorContext(actor_id='user_001'),
            subject=FakeSubjectContext(subject_id='msg_001'),
            thread_id='thread_001',
        )

        keys = {}
        for scope in VALID_STATE_SCOPES:
            keys[scope] = build_state_scope_key(scope, event, binding, descriptor)

        for scope in ('conversation', 'actor', 'subject', 'runner'):
            assert keys[scope].startswith(f'{scope}:v2:')
        # No two scopes may share a key.
        assert len(set(keys.values())) == len(keys)

    def test_scope_key_missing_identity_returns_none(self):
        """Scopes without their identifying id give None; the runner scope still resolves."""
        descriptor = make_descriptor()
        binding = FakeBinding()
        event = FakeEventEnvelope(conversation_id=None, actor=None, subject=None)

        def key_for(scope):
            return build_state_scope_key(scope, event, binding, descriptor)

        assert key_for('conversation') is None
        assert key_for('subject') is None
        assert key_for('runner') is not None

    def test_build_state_context(self):
        """The built context exposes the identities and a key for every scope."""
        descriptor = make_descriptor()
        binding = FakeBinding(binding_id='binding_a')
        event = FakeEventEnvelope(
            conversation_id='conv_001',
            actor=FakeActorContext(actor_id='user_001'),
            subject=FakeSubjectContext(subject_id='msg_001'),
        )

        context = build_state_context(event, binding, descriptor)

        assert context['binding_identity'] == 'binding_a'
        assert context['conversation_id'] == 'conv_001'
        assert context['actor_id'] == 'user_001'
        assert set(context['scope_keys']) == {'conversation', 'actor', 'subject', 'runner'}
|
||||
|
||||
|
||||
class TestPersistentStateStore:
    """Tests for the database-backed ``PersistentStateStore``."""

    @pytest.fixture
    async def db_engine(self):
        """Yield an async SQLite engine backed by a throwaway database file.

        The schema from ``Base.metadata`` is created before the engine is handed
        to the test. Cleanup runs in a ``finally`` block so the engine is
        disposed and the file removed even if schema creation raises.
        """
        # delete=False: only the path is needed; the engine opens the file itself.
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
            db_path = f.name

        engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)

        try:
            from langbot.pkg.entity.persistence.base import Base

            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

            yield engine
        finally:
            await engine.dispose()
            os.unlink(db_path)

    @pytest.fixture
    async def persistent_store(self, db_engine):
        """Yield a store bound to ``db_engine`` and clear it after the test."""
        store = PersistentStateStore(db_engine)
        yield store
        await store.clear_all()

    @pytest.mark.asyncio
    async def test_build_snapshot_empty(self, persistent_store):
        """With nothing stored, only the conversation scope carries the event's conversation id."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)

        assert snapshot['conversation'] == {'external.conversation_id': 'conv_001'}
        assert snapshot['actor'] == {}
        assert snapshot['subject'] == {}
        assert snapshot['runner'] == {}

    @pytest.mark.asyncio
    async def test_state_set_and_get(self, persistent_store):
        """A value written through an update appears in the next snapshot."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        success, error = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'test_key', {'nested': 'value'}, None
        )
        assert success is True
        assert error is None

        snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
        assert snapshot['conversation']['test_key'] == {'nested': 'value'}

    @pytest.mark.asyncio
    async def test_concurrent_first_state_set_uses_upsert(self, persistent_store):
        """Eight concurrent first writes to one key all succeed and one of them is stored."""
        scope_key = 'conversation:runner:binding:conv_concurrent'

        async def set_value(value: int):
            return await persistent_store.state_set(
                scope_key=scope_key,
                state_key='external.concurrent',
                value={'value': value},
                runner_id='plugin:test/my-runner/default',
                binding_identity='binding_001',
                scope='conversation',
            )

        results = await asyncio.gather(*(set_value(value) for value in range(8)))

        assert all(success is True and error is None for success, error in results)
        # Which writer wins is not fixed; any of the written values is acceptable.
        stored = await persistent_store.state_get(scope_key, 'external.concurrent')
        assert stored in [{'value': value} for value in range(8)]

    @pytest.mark.asyncio
    async def test_state_api_methods_normalize_public_key_aliases(self, persistent_store):
        """set/get/list/delete treat ``conversation_id`` as ``external.conversation_id``."""
        scope_key = 'conversation:runner:binding:conv_001'

        success, error = await persistent_store.state_set(
            scope_key=scope_key,
            state_key='conversation_id',
            value='conv_001',
            runner_id='plugin:test/my-runner/default',
            binding_identity='binding_001',
            scope='conversation',
        )

        assert success is True
        assert error is None
        assert await persistent_store.state_get(scope_key, 'external.conversation_id') == 'conv_001'
        assert await persistent_store.state_get(scope_key, 'conversation_id') == 'conv_001'

        keys, _ = await persistent_store.state_list(scope_key, prefix='conversation_id')
        assert keys == ['external.conversation_id']

        assert await persistent_store.state_delete(scope_key, 'conversation_id') is True
        assert await persistent_store.state_get(scope_key, 'external.conversation_id') is None

    @pytest.mark.asyncio
    async def test_binding_isolation(self, persistent_store):
        """State written under one binding is not visible under another binding."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding_a = FakeBinding(binding_id='binding_a')
        binding_b = FakeBinding(binding_id='binding_b')

        await persistent_store.apply_update_from_event(
            event, binding_a, descriptor, 'conversation', 'key', 'value_a', None
        )

        snapshot_b = await persistent_store.build_snapshot_from_event(event, binding_b, descriptor)
        assert snapshot_b['conversation'] == {'external.conversation_id': 'conv_001'}

        snapshot_a = await persistent_store.build_snapshot_from_event(event, binding_a, descriptor)
        assert snapshot_a['conversation']['key'] == 'value_a'

    @pytest.mark.asyncio
    async def test_policy_disable_state(self, persistent_store):
        """With enable_state=False the snapshot is empty and updates are rejected."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding(state_policy=StatePolicy(enable_state=False))

        snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
        assert snapshot == {'conversation': {}, 'actor': {}, 'subject': {}, 'runner': {}}

        success, error = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'key', 'value', None
        )
        assert success is False
        assert 'disabled' in error.lower()

    @pytest.mark.asyncio
    async def test_policy_scope_restriction(self, persistent_store):
        """Updates succeed for a scope listed in the policy and fail for one that is not."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(
            conversation_id='conv_001',
            actor=FakeActorContext(actor_id='user_001'),
        )
        binding = FakeBinding(state_policy=StatePolicy(state_scopes=['conversation']))

        success_conv, _ = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'key', 'value_conv', None
        )
        assert success_conv is True

        success_actor, error_actor = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'actor', 'key', 'value_actor', None
        )
        assert success_actor is False
        assert 'not enabled' in error_actor.lower()

    @pytest.mark.asyncio
    async def test_value_json_size_limit(self, persistent_store):
        """A 300 KiB string value is rejected with an 'exceeds limit' error."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        large_value = 'x' * (300 * 1024)

        success, error = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'key', large_value, None
        )
        assert success is False
        assert 'exceeds limit' in error.lower()

    @pytest.mark.asyncio
    async def test_value_not_json_serializable(self, persistent_store):
        """A value containing a set is rejected with an error mentioning JSON."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        success, error = await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'key', {'key': {1, 2, 3}}, None
        )
        assert success is False
        assert 'json' in error.lower()

    @pytest.mark.asyncio
    async def test_state_list(self, persistent_store):
        """state_list returns every stored key and can filter by prefix."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'external.id', '123', None
        )
        await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'external.name', 'test', None
        )
        await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'memory.key', 'value', None
        )

        # Address the same storage slot the updates above were written to.
        scope_key = build_state_scope_key('conversation', event, binding, descriptor)

        keys, has_more = await persistent_store.state_list(scope_key)
        assert len(keys) == 3
        assert has_more is False

        keys_ext, _ = await persistent_store.state_list(scope_key, prefix='external.')
        assert len(keys_ext) == 2
        assert 'external.id' in keys_ext
        assert 'external.name' in keys_ext

    @pytest.mark.asyncio
    async def test_state_delete(self, persistent_store):
        """Deleting a key removes it from snapshots; deleting it again reports False."""
        descriptor = make_descriptor()
        event = FakeEventEnvelope(conversation_id='conv_001')
        binding = FakeBinding()

        await persistent_store.apply_update_from_event(
            event, binding, descriptor, 'conversation', 'key', 'value', None
        )
        snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
        assert snapshot['conversation']['key'] == 'value'

        scope_key = build_state_scope_key('conversation', event, binding, descriptor)
        deleted = await persistent_store.state_delete(scope_key, 'key')
        assert deleted is True

        snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
        assert 'key' not in snapshot['conversation']

        deleted_again = await persistent_store.state_delete(scope_key, 'key')
        assert deleted_again is False
|
||||
@@ -13,10 +13,13 @@ Source: src/langbot/pkg/api/http/service/model.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.default_config import AgentRunnerDefaultConfigService
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.api.http.service.model import (
|
||||
LLMModelsService,
|
||||
EmbeddingModelsService,
|
||||
@@ -29,6 +32,7 @@ from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, Reran
|
||||
|
||||
|
||||
# Module-level pytest mark: applies the asyncio marker to every test in this module.
pytestmark = pytest.mark.asyncio
|
||||
# Runner id used as the pipeline's selected runner and as the runner_config key in tests below.
RUNNER_ID = 'plugin:test/runner/default'
|
||||
|
||||
|
||||
def _create_mock_llm_model(
|
||||
@@ -101,6 +105,22 @@ def _create_mock_result(items: list = None, first_item=None):
|
||||
return result
|
||||
|
||||
|
||||
class FakeAgentRunnerRegistry:
    """Registry stub whose ``get`` returns a descriptor for whatever runner id is requested."""

    async def get(self, runner_id, bound_plugins=None):
        """Return a plugin-sourced descriptor for *runner_id*; *bound_plugins* is ignored.

        The descriptor's schema declares a single ``model`` field of type
        ``model-fallback-selector`` with an empty primary/fallbacks default.
        """
        model_field = {
            'name': 'model',
            'type': 'model-fallback-selector',
            'default': {'primary': '', 'fallbacks': []},
        }
        descriptor = AgentRunnerDescriptor(
            id=runner_id,
            source='plugin',
            label={'en_US': 'Test Runner'},
            plugin_author='test',
            plugin_name='runner',
            runner_name='default',
            config_schema=[model_field],
            permissions={'models': ['invoke']},
        )
        return descriptor
|
||||
|
||||
|
||||
class TestParseProviderApiKeys:
|
||||
"""Tests for _parse_provider_api_keys helper function."""
|
||||
|
||||
@@ -451,6 +471,52 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
assert runtime_entity.extra_args == {'temperature': 0.2}
|
||||
assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
async def test_create_llm_model_auto_sets_schema_defined_default_pipeline_model(self):
    """Creating a model fills the runner's schema-declared ``model`` field in the pipeline.

    The pipeline starts with an empty primary model under the runner's own
    config entry; after creation the new model uuid must be its primary.
    """
    new_uuid = 'new-model-uuid'

    # Application stub with just the collaborators the services touch.
    model_mgr = SimpleNamespace(
        provider_dict={'provider-uuid': Mock()},
        llm_models=[],
        load_llm_model_with_provider=AsyncMock(return_value=Mock()),
    )
    ap = SimpleNamespace(
        logger=Mock(),
        persistence_mgr=SimpleNamespace(),
        model_mgr=model_mgr,
        pipeline_service=SimpleNamespace(update_pipeline=AsyncMock()),
        agent_runner_registry=FakeAgentRunnerRegistry(),
    )
    ap.agent_runner_default_config_service = AgentRunnerDefaultConfigService(ap)

    # Every persistence query answers with this one pipeline.
    ai_config = {
        'runner': {'id': RUNNER_ID},
        'runner_config': {
            RUNNER_ID: {
                'model': {'primary': '', 'fallbacks': []},
            },
        },
    }
    pipeline = SimpleNamespace(uuid='pipeline-uuid', config={'ai': ai_config})
    ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=pipeline))

    service = LLMModelsService(ap)
    payload = {
        'uuid': new_uuid,
        'name': 'New LLM',
        'provider_uuid': 'provider-uuid',
        'abilities': [],
        'extra_args': {},
    }

    model_uuid = await service.create_llm_model(payload, preserve_uuid=True)

    assert model_uuid == 'new-model-uuid'
    update_pipeline = ap.pipeline_service.update_pipeline
    update_pipeline.assert_awaited_once()
    # The second positional argument of update_pipeline carries the new config.
    updated_config = update_pipeline.await_args.args[1]['config']
    assert updated_config['ai']['runner_config'][RUNNER_ID]['model'] == {
        'primary': 'new-model-uuid',
        'fallbacks': [],
    }
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Tests for dynamic default pipeline config rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.api.http.service.pipeline import PipelineService
|
||||
|
||||
|
||||
class FakeLogger:
    """Minimal logger stand-in that silently discards warnings."""

    def warning(self, msg, *args, **kwargs):
        """Accept and ignore a warning.

        Extra positional and keyword arguments are accepted so code under test
        can call this the way it would call ``logging.Logger.warning`` (lazy
        %-format arguments, ``exc_info=...``) without raising ``TypeError``.

        Returns:
            None.
        """
        return None
|
||||
|
||||
|
||||
class FakeRegistry:
    """Registry stub that serves a fixed collection of runners."""

    def __init__(self, runners):
        """Keep *runners* as given; it is returned unchanged by ``list_runners``."""
        self.runners = runners

    async def list_runners(self, bound_plugins=None):
        """Return the stored runners; *bound_plugins* is accepted but not used."""
        installed = self.runners
        return installed
|
||||
|
||||
|
||||
def make_runner(runner_id: str, config_schema: list[dict]):
    """Build a plugin-sourced descriptor from a ``plugin:<author>/<plugin>/<runner>`` id.

    The ``plugin:`` prefix is dropped and the remainder split on ``/``; the first
    three segments become the author, plugin name and runner name. The id itself
    doubles as the English label.
    """
    segments = runner_id.removeprefix('plugin:').split('/')
    author = segments[0]
    plugin = segments[1]
    runner = segments[2]
    return AgentRunnerDescriptor(
        id=runner_id,
        source='plugin',
        label={'en_US': runner_id},
        plugin_author=author,
        plugin_name=plugin,
        runner_name=runner,
        config_schema=config_schema,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_pipeline_config_uses_first_installed_runner_schema():
    """The default config selects the first listed runner and renders its schema defaults."""
    custom_id = 'plugin:alice/custom-agent/default'

    local_schema = [
        {'name': 'model', 'type': 'model-fallback-selector', 'default': {'primary': '', 'fallbacks': []}},
        {'name': 'prompt', 'type': 'prompt-editor', 'default': [{'role': 'system', 'content': 'Hello'}]},
    ]
    local_agent = make_runner('plugin:langbot/local-agent/default', local_schema)
    custom_agent = make_runner(
        'plugin:alice/custom-agent/default',
        [{'name': 'api-key', 'type': 'string', 'default': ''}],
    )

    # The custom runner is listed first, ahead of the local agent.
    registry = FakeRegistry([custom_agent, local_agent])
    ap = SimpleNamespace(logger=FakeLogger(), agent_runner_registry=registry)

    config = await PipelineService(ap).get_default_pipeline_config()

    ai_section = config['ai']
    assert ai_section['runner']['id'] == custom_id
    assert ai_section['runner_config'] == {
        custom_id: {
            'api-key': '',
        },
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_pipeline_config_stays_neutral_without_installed_runners():
    """With no runners listed, the default config has an empty runner id and runner_config."""
    empty_registry = FakeRegistry([])
    ap = SimpleNamespace(logger=FakeLogger(), agent_runner_registry=empty_registry)

    service = PipelineService(ap)
    config = await service.get_default_pipeline_config()

    ai_section = config['ai']
    assert ai_section['runner']['id'] == ''
    assert ai_section['runner_config'] == {}
|
||||
@@ -181,6 +181,23 @@ def make_app(
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_box_session_id_reads_current_runner_config():
    """The session-id template is looked up under ``runner_config[<current runner id>]``."""
    runner_settings = {
        'box-session-id-template': 'bot-{launcher_id}-{sender_id}',
    }
    query = make_query(101)
    query.pipeline_config = {
        'ai': {
            'runner': {'id': 'plugin:langbot/local-agent/default'},
            'runner_config': {
                'plugin:langbot/local-agent/default': runner_settings,
            },
        },
    }
    service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient))

    resolved = service.resolve_box_session_id(query)

    assert resolved == 'bot-test_user-test_user'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_box_service_without_explicit_client_initializes_internal_connector(monkeypatch: pytest.MonkeyPatch):
|
||||
connector = Mock()
|
||||
|
||||
@@ -27,6 +27,9 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
from langbot.pkg.pipeline import entities as pipeline_entities
|
||||
|
||||
|
||||
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class MockApplication:
|
||||
"""Mock Application object providing all basic dependencies needed by stages"""
|
||||
|
||||
@@ -202,8 +205,13 @@ def sample_query(sample_message_chain, sample_message_event, mock_adapter):
|
||||
bot_uuid='test-bot-uuid',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||
'runner': {'id': DEFAULT_RUNNER_ID},
|
||||
'runner_config': {
|
||||
DEFAULT_RUNNER_ID: {
|
||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
||||
'prompt': [{'role': 'system', 'content': 'test-prompt'}],
|
||||
},
|
||||
},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
@@ -227,8 +235,13 @@ def sample_pipeline_config():
|
||||
"""Provides sample pipeline configuration"""
|
||||
return {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||
'runner': {'id': DEFAULT_RUNNER_ID},
|
||||
'runner_config': {
|
||||
DEFAULT_RUNNER_ID: {
|
||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
||||
'prompt': [{'role': 'system', 'content': 'test-prompt'}],
|
||||
},
|
||||
},
|
||||
},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.agent.runner.errors import RunnerNotFoundError
|
||||
from langbot.pkg.pipeline.controller import Controller
|
||||
|
||||
|
||||
def make_app():
    """Assemble a minimal application namespace for the controller tests.

    It carries a pipeline concurrency setting of 10, a ``MagicMock`` logger and
    ``AsyncMock`` stand-ins for pipeline lookup, session lookup (resolving to
    an empty namespace) and the steering-claim call.
    """
    return SimpleNamespace(
        instance_config=SimpleNamespace(data={'concurrency': {'pipeline': 10}}),
        logger=MagicMock(),
        pipeline_mgr=SimpleNamespace(get_pipeline_by_uuid=AsyncMock()),
        sess_mgr=SimpleNamespace(get_session=AsyncMock(return_value=SimpleNamespace())),
        agent_run_orchestrator=SimpleNamespace(try_claim_steering_from_query=AsyncMock()),
    )
|
||||
|
||||
|
||||
def make_pipeline():
    """Return a pipeline stand-in bound to the ``test/runner`` plugin and no MCP servers."""
    entity = SimpleNamespace(config={'ai': {'runner': {'id': 'plugin:test/runner/default'}}})
    return SimpleNamespace(
        pipeline_entity=entity,
        bound_plugins=['test/runner'],
        bound_mcp_servers=[],
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_try_claim_steering_returns_false_when_runner_lookup_fails():
    """A ``RunnerNotFoundError`` from the orchestrator yields ``False`` and exactly one warning."""
    app = make_app()
    app.pipeline_mgr.get_pipeline_by_uuid.return_value = make_pipeline()
    orchestrator = app.agent_run_orchestrator
    orchestrator.try_claim_steering_from_query.side_effect = RunnerNotFoundError(
        'plugin:missing/runner/default'
    )
    controller = Controller(app)
    query = SimpleNamespace(query_id=1, pipeline_uuid='pipeline-001', variables={})

    outcome = await controller._try_claim_steering_before_session_slot(query)

    assert outcome is False
    app.logger.warning.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_try_claim_steering_sets_pipeline_context_before_claiming():
    """A successful claim leaves the pipeline config and bound plugins attached to the query."""
    pipeline = make_pipeline()
    app = make_app()
    app.pipeline_mgr.get_pipeline_by_uuid.return_value = pipeline
    claim_mock = app.agent_run_orchestrator.try_claim_steering_from_query
    claim_mock.return_value = True
    controller = Controller(app)
    query = SimpleNamespace(query_id=2, pipeline_uuid='pipeline-002', variables={})

    outcome = await controller._try_claim_steering_before_session_slot(query)

    assert outcome is True
    # The very same config object, not a copy, must be attached to the query.
    assert query.pipeline_config is pipeline.pipeline_entity.config
    assert query.variables['_pipeline_bound_plugins'] == ['test/runner']
    claim_mock.assert_awaited_once_with(query)
|
||||
@@ -1,321 +0,0 @@
|
||||
"""
|
||||
Unit tests for ConversationMessageTruncator (msgtrun) pipeline stage.
|
||||
|
||||
Tests cover:
|
||||
- Normal truncation behavior based on max-round
|
||||
- Boundary length handling
|
||||
- Empty message handling
|
||||
- Multi-message chain truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
from tests.factories import (
|
||||
FakeApp,
|
||||
text_query,
|
||||
)
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
def get_msgtrun_module():
    """Import the msgtrun stage module lazily to stay clear of circular imports."""
    # pipelinemgr goes first: importing it triggers stage registration.
    import_module('langbot.pkg.pipeline.pipelinemgr')
    stage_module = import_module('langbot.pkg.pipeline.msgtrun.msgtrun')
    return stage_module
|
||||
|
||||
|
||||
def get_truncator_module():
    """Import and return the truncator base module on demand."""
    module_path = 'langbot.pkg.pipeline.msgtrun.truncator'
    return import_module(module_path)
|
||||
|
||||
|
||||
def get_entities_module():
    """Import and return the pipeline entities module on demand."""
    module_path = 'langbot.pkg.pipeline.entities'
    return import_module(module_path)
|
||||
|
||||
|
||||
def get_round_truncator_module():
    """Import and return the round-truncator module on demand."""
    module_path = 'langbot.pkg.pipeline.msgtrun.truncators.round'
    return import_module(module_path)
|
||||
|
||||
|
||||
def make_truncate_config(max_round: int = 5):
    """Return a pipeline config whose ``ai.local-agent.max-round`` is *max_round*."""
    local_agent_settings = {'max-round': max_round}
    return {'ai': {'local-agent': local_agent_settings}}
|
||||
|
||||
|
||||
class TestConversationMessageTruncatorInit:
    """Initialization behaviour of the ConversationMessageTruncator stage."""

    @pytest.mark.asyncio
    async def test_initialize_round_truncator(self):
        """After ``initialize`` the stage holds a ``Truncator`` instance in ``trun``."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()
        stage = msgtrun.ConversationMessageTruncator(FakeApp())

        await stage.initialize(make_truncate_config())

        selected = stage.trun
        assert selected is not None
        assert isinstance(selected, truncator.Truncator)

    @pytest.mark.asyncio
    async def test_initialize_unknown_truncator_raises(self):
        """With no truncators registered, ``initialize`` raises ``ValueError``."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()

        # Snapshot the registry so it can be put back however the test ends.
        snapshot = truncator.preregistered_truncators.copy()

        try:
            # An empty registry means the configured method cannot be found.
            truncator.preregistered_truncators = []
            stage = msgtrun.ConversationMessageTruncator(FakeApp())
            config = make_truncate_config()

            with pytest.raises(ValueError, match='Unknown truncator'):
                await stage.initialize(config)
        finally:
            truncator.preregistered_truncators = snapshot
|
||||
|
||||
|
||||
class TestRoundTruncatorProcess:
    """Behaviour of the stage's ``process`` step under various ``max-round`` settings."""

    @staticmethod
    def _history(*turns):
        """Turn ``(role, content)`` pairs into provider ``Message`` objects, in order."""
        return [provider_message.Message(role=role, content=content) for role, content in turns]

    @staticmethod
    async def _process(pipeline_config, text, messages):
        """Initialise a fresh stage with *pipeline_config* and run it on one query.

        The query is built from *text* and given *messages* as its history.
        Returns ``(result, entities_module)`` so callers can compare against
        ``ResultType`` without importing the entities module themselves.
        """
        msgtrun = get_msgtrun_module()
        entities = get_entities_module()
        stage = msgtrun.ConversationMessageTruncator(FakeApp())
        await stage.initialize(pipeline_config)

        query = text_query(text)
        query.pipeline_config = pipeline_config
        query.messages = messages

        result = await stage.process(query, 'ConversationMessageTruncator')
        return result, entities

    @pytest.mark.asyncio
    async def test_truncate_within_limit(self):
        """A history that fits inside ``max-round`` comes back untouched."""
        # Five messages spanning three user rounds, below the limit of five.
        history = self._history(
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'current message'),
        )

        result, entities = await self._process(make_truncate_config(max_round=5), 'current message', history)

        assert result.result_type == entities.ResultType.CONTINUE
        assert len(result.new_query.messages) == 5

    @pytest.mark.asyncio
    async def test_truncate_exceeds_limit(self):
        """A history longer than ``max-round`` is cut down to the newest rounds.

        With max-round=2 and seven messages (u1, a1, u2, a2, u3, a3, current),
        exactly three messages are expected back: u3, a3 and the current user
        message. The expectation encodes a backwards walk that counts each user
        message as one round and stops once ``max-round`` rounds are gathered.
        """
        history = self._history(
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'message 3'),
            ('assistant', 'response 3'),
            ('user', 'current message'),
        )

        result, entities = await self._process(make_truncate_config(max_round=2), 'current message', history)

        assert result.result_type == entities.ResultType.CONTINUE
        kept = result.new_query.messages
        assert len(kept) == 3

        first, second, third = kept[0], kept[1], kept[2]
        assert first.role == 'user'
        assert first.content == 'message 3'
        assert second.role == 'assistant'
        assert second.content == 'response 3'
        assert third.role == 'user'
        assert third.content == 'current message'

    @pytest.mark.asyncio
    async def test_truncate_empty_messages(self):
        """An empty history stays empty."""
        result, entities = await self._process(make_truncate_config(), 'hello', [])

        assert result.result_type == entities.ResultType.CONTINUE
        assert len(result.new_query.messages) == 0

    @pytest.mark.asyncio
    async def test_truncate_single_message(self):
        """A lone message survives truncation."""
        history = self._history(('user', 'hello'))

        result, entities = await self._process(make_truncate_config(), 'hello', history)

        assert result.result_type == entities.ResultType.CONTINUE
        assert len(result.new_query.messages) == 1

    @pytest.mark.asyncio
    async def test_truncate_preserves_order(self):
        """The messages that remain keep their original relative order."""
        history = self._history(
            ('user', 'user1'),
            ('assistant', 'asst1'),
            ('user', 'user2'),
            ('assistant', 'asst2'),
            ('user', 'user3'),
        )

        result, entities = await self._process(make_truncate_config(max_round=2), 'current', history)

        assert result.result_type == entities.ResultType.CONTINUE
        kept = [(msg.role, msg.content) for msg in result.new_query.messages]
        assert kept == [
            ('user', 'user2'),
            ('assistant', 'asst2'),
            ('user', 'user3'),
        ]

    @pytest.mark.asyncio
    async def test_truncate_max_round_one(self):
        """With max-round=1 only the latest user message remains."""
        history = self._history(
            ('user', 'old1'),
            ('assistant', 'old1_resp'),
            ('user', 'current'),
        )

        result, entities = await self._process(make_truncate_config(max_round=1), 'current', history)

        assert result.result_type == entities.ResultType.CONTINUE
        kept = [(msg.role, msg.content) for msg in result.new_query.messages]
        assert kept == [('user', 'current')]
|
||||
|
||||
|
||||
class TestRoundTruncatorDirect:
    """Direct tests for the RoundTruncator class."""

    @pytest.mark.asyncio
    async def test_round_truncator_direct_process(self):
        """Calling ``truncate`` on the 'round' truncator returns an object exposing ``messages``."""
        truncator_mod = get_truncator_module()

        app = FakeApp()

        # Look the class up among the preregistered truncators. Fail with a
        # clear message when it is absent instead of hitting an unbound local
        # further down.
        round_cls = next(
            (cls for cls in truncator_mod.preregistered_truncators if cls.name == 'round'),
            None,
        )
        assert round_cls is not None, "no truncator named 'round' is preregistered"
        trun = round_cls(app)

        query = text_query('hello')
        query.pipeline_config = make_truncate_config(max_round=3)
        query.messages = [
            provider_message.Message(role='user', content='m1'),
            provider_message.Message(role='assistant', content='r1'),
            provider_message.Message(role='user', content='m2'),
            provider_message.Message(role='assistant', content='r2'),
            provider_message.Message(role='user', content='hello'),
        ]

        result = await trun.truncate(query)

        assert result is not None
        assert hasattr(result, 'messages')
|
||||
@@ -1,353 +0,0 @@
|
||||
"""
|
||||
Unit tests for N8nServiceAPIRunner._process_response
|
||||
|
||||
Tests cover four scenarios:
|
||||
- Stream adapter + n8n stream format (type:item/end)
|
||||
- Stream adapter + n8n plain JSON
|
||||
- Non-stream adapter + n8n stream format
|
||||
- Non-stream adapter + n8n plain JSON
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
# Importing n8nsvapi directly runs into a circular import:
#   n8nsvapi -> runner -> app -> pipelinemgr -> all runners -> runner (partially init)
# To break the chain, stand-ins are placed in ``sys.modules`` for the duration
# of the import only. The previous entries are put back in a ``finally`` clause
# so this module does NOT pollute sys.modules for other test modules (e.g. ones
# importing the real LocalAgentRunner, which would otherwise inherit ``object``
# and break). This mirrors master's intent, but the try/finally means a raised
# import cannot leave the global namespace stubbed, and
# ``langbot.pkg.utils.httpclient`` (not stubbed on master) is covered too.
_runner_stub = MagicMock()
_runner_stub.runner_class = lambda name: (lambda cls: cls)  # decorator factory returning the class unchanged
_runner_stub.RequestRunner = object
_import_stubs = {
    'langbot.pkg.provider.runner': _runner_stub,
    'langbot.pkg.core.app': MagicMock(),
    'langbot.pkg.utils.httpclient': MagicMock(),
}
# A saved value of None marks a module that was not loaded before stubbing.
_saved_modules = {_name: sys.modules.get(_name) for _name in _import_stubs}
sys.modules.update(_import_stubs)
try:
    from langbot.pkg.provider.runners.n8nsvapi import N8nServiceAPIRunner
finally:
    for _name, _original in _saved_modules.items():
        if _original is not None:
            sys.modules[_name] = _original
        else:
            sys.modules.pop(_name, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_runner(output_key: str = 'response') -> N8nServiceAPIRunner:
    """Create an ``N8nServiceAPIRunner`` on a mocked app with an auth-less test webhook config."""
    n8n_settings = {
        'webhook-url': 'http://test-n8n/webhook',
        'output-key': output_key,
        'auth-type': 'none',
    }
    ap = Mock()
    ap.logger = Mock()
    return N8nServiceAPIRunner(ap, {'ai': {'n8n-service-api': n8n_settings}})
|
||||
|
||||
|
||||
def make_mock_response(chunks: list[bytes | str], status: int = 200):
    """Build a minimal response mock whose ``content.iter_chunked`` replays *chunks*."""

    async def iter_chunked(size):
        # ``size`` is accepted for signature compatibility; the chunks are
        # replayed exactly as given.
        for piece in chunks:
            yield piece

    body = Mock()
    body.iter_chunked = iter_chunked

    response = Mock()
    response.status = status
    response.content = body
    return response
|
||||
|
||||
|
||||
async def collect_chunks(runner: N8nServiceAPIRunner, chunks: list[bytes | str]):
    """Feed *chunks* through ``runner._process_response`` and return everything it yields."""
    response = make_mock_response(chunks)
    return [produced async for produced in runner._process_response(response)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_response: stream format (type:item/end)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_format_single_item():
    """One ``item`` plus ``end`` in a single network chunk yields one final chunk with the content."""
    payload = b'{"type":"item","content":"hello"}{"type":"end"}'

    emitted = await collect_chunks(make_runner(), [payload])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == 'hello'
    assert only.msg_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_format_multi_item_accumulates():
    """Content from several ``item`` events is concatenated into the final chunk."""
    network_chunks = [
        b'{"type":"item","content":"foo"}',
        b'{"type":"item","content":"bar"}',
        b'{"type":"end"}',
    ]

    emitted = await collect_chunks(make_runner(), network_chunks)

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == 'foobar'
    assert only.msg_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_format_batches_every_8_items():
    """Eight items produce one intermediate chunk, followed by the final one."""
    item_events = ''.join(f'{{"type":"item","content":"{i}"}}' for i in range(8))
    payload = (item_events + '{"type":"end"}').encode()

    emitted = await collect_chunks(make_runner(), [payload])

    assert len(emitted) == 2
    intermediate, final = emitted
    assert intermediate.is_final is False
    assert intermediate.content == '01234567'
    assert intermediate.msg_sequence == 1
    assert final.is_final is True
    assert final.content == '01234567'
    assert final.msg_sequence == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_format_split_across_network_chunks():
    """A JSON object cut in half by the network boundary is still decoded as one item."""
    head = b'{"type":"item","con'
    tail = b'tent":"world"}{"type":"end"}'

    emitted = await collect_chunks(make_runner(), [head, tail])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == 'world'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_format_no_spurious_empty_yield():
    """chunk_idx==0 guard prevents spurious empty yield before any item is received."""
    # A single item followed by end: exactly one chunk is expected, carrying
    # the item's content.
    payload = b'{"type":"item","content":"x"}{"type":"end"}'

    emitted = await collect_chunks(make_runner(), [payload])

    assert len(emitted) == 1
    assert emitted[0].content == 'x'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_response: plain JSON fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_plain_json_with_output_key():
    """A plain JSON object containing the output key yields that key's value."""
    runner = make_runner(output_key='response')
    body = json.dumps({'response': 'hello world'}).encode()

    emitted = await collect_chunks(runner, [body])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == 'hello world'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_plain_json_output_key_not_found():
    """When the output key is absent, the whole JSON document is returned as the content."""
    runner = make_runner(output_key='response')
    payload = {'other_key': 'hello'}
    body = json.dumps(payload).encode()

    emitted = await collect_chunks(runner, [body])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    # Compare parsed values so JSON formatting differences do not matter.
    assert json.loads(only.content) == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_plain_json_output_key_empty_string():
    """An output key holding an empty string yields an empty string, not the whole JSON."""
    runner = make_runner(output_key='response')
    body = json.dumps({'response': ''}).encode()

    emitted = await collect_chunks(runner, [body])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_plain_json_non_dict_response():
    """A JSON array (not an object) is passed through as raw text."""
    body = b'["a", "b"]'

    emitted = await collect_chunks(make_runner(), [body])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == '["a", "b"]'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_returns_raw_text():
    """A body that is not JSON at all is returned verbatim."""
    body = b'plain text response'

    emitted = await collect_chunks(make_runner(), [body])

    assert len(emitted) == 1
    only = emitted[0]
    assert only.is_final is True
    assert only.content == 'plain text response'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _call_webhook: output type depends on is_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_query(is_stream: bool):
    """Build a minimal Query mock whose adapter reports *is_stream* for stream support."""
    adapter = AsyncMock()
    adapter.is_stream_output_supported = AsyncMock(return_value=is_stream)

    conversation = Mock()
    conversation.uuid = 'test-uuid'

    launcher_type = Mock()
    launcher_type.value = 'person'

    session = Mock()
    session.using_conversation = conversation
    session.launcher_type = launcher_type
    session.launcher_id = '12345'

    user_message = Mock()
    user_message.content = 'hi'

    query = Mock()
    query.adapter = adapter
    query.session = session
    query.user_message = user_message
    query.variables = {}
    return query
|
||||
|
||||
|
||||
def make_http_session_mock(response_bytes: bytes, status: int = 200):
    """Return a session mock whose ``post()`` serves *response_bytes* via ``async with``.

    Intended as the return value for a patched ``httpclient.get_session()``.
    ``post(...)`` returns an async context manager that resolves to a mocked
    response carrying *status* and replaying *response_bytes* as its only
    chunk.
    """
    # make_mock_response already sets ``status`` on the response, so it is
    # not assigned a second time here.
    mock_response = make_mock_response([response_bytes], status=status)

    mock_cm = AsyncMock()
    mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
    # A falsy __aexit__ result lets exceptions raised inside the block propagate.
    mock_cm.__aexit__ = AsyncMock(return_value=False)

    mock_session = Mock()
    # post() is a plain (non-async) callable returning the context manager.
    mock_session.post = Mock(return_value=mock_cm)
    return mock_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_webhook_nonstream_adapter_plain_json():
    """Non-stream adapter + plain JSON: one ``Message`` carrying the output-key value."""
    runner = make_runner(output_key='response')
    query = make_query(is_stream=False)
    body = json.dumps({'response': 'result text'}).encode()
    http_session = make_http_session_mock(body)

    with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
        produced = [msg async for msg in runner._call_webhook(query)]

    assert len(produced) == 1
    only = produced[0]
    assert isinstance(only, provider_message.Message)
    assert only.content == 'result text'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_webhook_stream_adapter_stream_format():
    """Stream adapter + stream format: only ``MessageChunk`` objects, the last one final."""
    runner = make_runner()
    query = make_query(is_stream=True)
    body = b'{"type":"item","content":"hi"}{"type":"end"}'
    http_session = make_http_session_mock(body)

    with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
        produced = [msg async for msg in runner._call_webhook(query)]

    assert all(isinstance(item, provider_message.MessageChunk) for item in produced)
    last = produced[-1]
    assert last.is_final is True
    assert last.content == 'hi'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_webhook_stream_adapter_plain_json():
    """Stream adapter + plain JSON: ``MessageChunk`` output ending in a final chunk with the value."""
    runner = make_runner(output_key='response')
    query = make_query(is_stream=True)
    body = json.dumps({'response': 'fallback'}).encode()
    http_session = make_http_session_mock(body)

    with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
        produced = [msg async for msg in runner._call_webhook(query)]

    assert all(isinstance(item, provider_message.MessageChunk) for item in produced)
    last = produced[-1]
    assert last.is_final is True
    assert last.content == 'fallback'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_call_webhook_nonstream_adapter_stream_format():
    """Non-stream adapter + stream format: one ``Message`` holding the accumulated content."""
    runner = make_runner()
    query = make_query(is_stream=False)
    body = b'{"type":"item","content":"foo"}{"type":"item","content":"bar"}{"type":"end"}'
    http_session = make_http_session_mock(body)

    with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
        produced = [msg async for msg in runner._call_webhook(query)]

    assert len(produced) == 1
    only = produced[0]
    assert isinstance(only, provider_message.Message)
    assert only.content == 'foobar'
|
||||
@@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction, RuntimeToLangBotAction
|
||||
|
||||
|
||||
@@ -27,6 +28,22 @@ def compiled_params(statement):
|
||||
return statement.compile().params
|
||||
|
||||
|
||||
def make_agent_resources(
    models: list[dict] | None = None,
    tools: list[dict] | None = None,
    knowledge_bases: list[dict] | None = None,
):
    """Return a minimal AgentRun resources payload for run-scoped action tests.

    Any of the three list arguments that is missing or empty is replaced by a
    fresh empty list. Files, storage flags and platform capabilities are fixed
    to empty / disabled values.
    """
    storage_flags = {'plugin_storage': False, 'workspace_storage': False}
    return {
        'models': models if models else [],
        'tools': tools if tools else [],
        'knowledge_bases': knowledge_bases if knowledge_bases else [],
        'files': [],
        'storage': storage_flags,
        'platform_capabilities': {},
    }
|
||||
|
||||
|
||||
class TestRagRerankAction:
|
||||
"""Tests for RAG rerank action handler."""
|
||||
|
||||
@@ -421,3 +438,433 @@ class TestHandlerQueryLookup:
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {'bot_uuid': 'test-bot-uuid'}
|
||||
|
||||
|
||||
class TestAgentRunProxyActions:
|
||||
"""Tests for AgentRunner proxy actions that need host Query semantics."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
mock_app = Mock()
|
||||
mock_app.logger = Mock()
|
||||
mock_app.query_pool = Mock()
|
||||
mock_app.query_pool.cached_queries = {}
|
||||
mock_app.model_mgr = Mock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock()
|
||||
mock_app.model_mgr.get_rerank_model_by_uuid = AsyncMock()
|
||||
mock_app.tool_mgr = Mock()
|
||||
mock_app.tool_mgr.execute_func_call = AsyncMock(return_value={'ok': True})
|
||||
return mock_app
|
||||
|
||||
@staticmethod
|
||||
def query(remove_think=True):
|
||||
return SimpleNamespace(
|
||||
pipeline_config={'output': {'misc': {'remove-think': remove_think}}},
|
||||
variables={},
|
||||
prompt=SimpleNamespace(
|
||||
messages=[provider_message.Message(role='system', content='effective prompt')]
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_prompt_returns_query_effective_prompt(self, app):
|
||||
"""GET_PROMPT returns the preprocessed Query prompt for the active run."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_get_prompt'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[900] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=900,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(),
|
||||
available_apis={'prompt_get': True},
|
||||
)
|
||||
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_PROMPT.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['prompt'][0]['role'] == 'system'
|
||||
assert response.data['prompt'][0]['content'] == 'effective prompt'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_restores_query_and_model_options(self, app):
|
||||
"""INVOKE_LLM passes Query, model extra_args and remove-think to provider."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_options'
|
||||
query = self.query(remove_think=True)
|
||||
app.query_pool.cached_queries[901] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=901,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_001'}]),
|
||||
)
|
||||
|
||||
provider = SimpleNamespace(
|
||||
invoke_llm=AsyncMock(return_value=provider_message.Message(role='assistant', content='ok')),
|
||||
)
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(
|
||||
abilities=['func_call'],
|
||||
extra_args={'temperature': 0.2, 'top_p': 0.8},
|
||||
),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
'funcs': [{
|
||||
'name': 'search',
|
||||
'human_desc': 'Search',
|
||||
'description': 'Search',
|
||||
'parameters': {'type': 'object'},
|
||||
}],
|
||||
'extra_args': {'temperature': 0.7, 'presence_penalty': 0.1},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
provider.invoke_llm.assert_awaited_once()
|
||||
kwargs = provider.invoke_llm.await_args.kwargs
|
||||
assert kwargs['query'] is query
|
||||
assert kwargs['extra_args'] == {
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.8,
|
||||
'presence_penalty': 0.1,
|
||||
}
|
||||
assert kwargs['remove_think'] is True
|
||||
assert [tool.name for tool in kwargs['funcs']] == ['search']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_returns_provider_usage(self, app):
|
||||
"""INVOKE_LLM includes optional provider usage in the action response."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 11,
|
||||
'completion_tokens': 7,
|
||||
'total_tokens': 18,
|
||||
'prompt_tokens_details': {'cached_tokens': 3},
|
||||
}
|
||||
|
||||
class UsageProvider:
|
||||
async def invoke_llm(self, **kwargs):
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
return provider_message.Message(role='assistant', content='ok')
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[905] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=905,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=UsageProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['message']['content'] == 'ok'
|
||||
assert response.data['usage'] == usage
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_restores_query_and_options(self, app):
|
||||
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
class StreamProvider:
|
||||
def __init__(self):
|
||||
self.kwargs = None
|
||||
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
yield provider_message.MessageChunk(role='assistant', content='hi')
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_options'
|
||||
query = self.query(remove_think=False)
|
||||
app.query_pool.cached_queries[902] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=902,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_001'}]),
|
||||
)
|
||||
|
||||
provider = StreamProvider()
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={'max_tokens': 128}),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
'funcs': [{
|
||||
'name': 'search',
|
||||
'human_desc': 'Search',
|
||||
'description': 'Search',
|
||||
'parameters': {'type': 'object'},
|
||||
}],
|
||||
'extra_args': {'max_tokens': 256},
|
||||
'remove_think': True,
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0]
|
||||
assert provider.kwargs['query'] is query
|
||||
assert provider.kwargs['extra_args'] == {'max_tokens': 256}
|
||||
assert provider.kwargs['remove_think'] is True
|
||||
assert provider.kwargs['funcs'] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_skips_none_chunks(self, app):
|
||||
"""INVOKE_LLM_STREAM tolerates provider heartbeat/no-op chunks."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
class StreamProvider:
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
yield None
|
||||
yield provider_message.MessageChunk(role='assistant', content=' done', is_final=True)
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_none_chunks'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[904] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=904,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_002'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=StreamProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_002',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_returns_provider_usage_event(self, app):
|
||||
"""INVOKE_LLM_STREAM emits a final usage-only action response when available."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 9,
|
||||
'completion_tokens': 4,
|
||||
'total_tokens': 13,
|
||||
'prompt_tokens_details': {'cached_tokens': 2},
|
||||
}
|
||||
|
||||
class StreamProvider:
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[906] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=906,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=StreamProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert responses[0].data['chunk']['content'] == 'ok'
|
||||
assert responses[1].data == {'usage': usage}
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_passes_current_query(self, app):
|
||||
"""CALL_TOOL passes the current Query back into tool execution."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_call_tool_query'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[903] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=903,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(tools=[{'tool_name': 'test/search'}]),
|
||||
)
|
||||
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.CALL_TOOL.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'tool_name': 'test/search',
|
||||
'parameters': {'q': 'langbot'},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert getattr(query, '_agent_run_session')['run_id'] == run_id
|
||||
app.tool_mgr.execute_func_call.assert_awaited_once_with(
|
||||
name='test/search',
|
||||
parameters={'q': 'langbot'},
|
||||
query=query,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_rerank_uses_authorized_model_and_extra_args(self, app):
|
||||
"""INVOKE_RERANK validates run-scoped model access and merges model extra_args."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_rerank_options'
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=904,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'rerank_001'}]),
|
||||
)
|
||||
|
||||
provider = SimpleNamespace(
|
||||
invoke_rerank=AsyncMock(return_value=[
|
||||
{'index': 0, 'relevance_score': 0.2},
|
||||
{'index': 1, 'relevance_score': 0.9},
|
||||
]),
|
||||
)
|
||||
rerank_model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(extra_args={'top_n': 5, 'return_documents': False}),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_rerank_model_by_uuid.return_value = rerank_model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'rerank_model_uuid': 'rerank_001',
|
||||
'query': 'hello',
|
||||
'documents': ['a', 'b'],
|
||||
'top_k': 1,
|
||||
'extra_args': {'top_n': 2},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['results'] == [{'index': 1, 'relevance_score': 0.9}]
|
||||
provider.invoke_rerank.assert_awaited_once()
|
||||
kwargs = provider.invoke_rerank.await_args.kwargs
|
||||
assert kwargs['extra_args'] == {'top_n': 2, 'return_documents': False}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
"""Tests for DifyServiceAPIRunner pure utility methods.
|
||||
|
||||
Tests the helper methods that don't require real Dify API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDifyExtractTextOutput:
|
||||
"""Tests for _extract_dify_text_output method."""
|
||||
|
||||
def _create_runner(self):
|
||||
"""Create runner instance."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'chat',
|
||||
'api-key': 'test-key',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
runner.dify_client = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
def test_extract_none_value(self):
|
||||
"""None returns empty string."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output(None)
|
||||
|
||||
assert result == ''
|
||||
|
||||
def test_extract_string_value(self):
|
||||
"""Plain string is returned."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('plain text')
|
||||
|
||||
assert result == 'plain text'
|
||||
|
||||
def test_extract_dict_with_content(self):
|
||||
"""Dict with 'content' key extracts content."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output({'content': 'extracted content'})
|
||||
|
||||
assert result == 'extracted content'
|
||||
|
||||
def test_extract_dict_without_content(self):
|
||||
"""Dict without 'content' key is JSON dumped."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output({'key': 'value'})
|
||||
|
||||
assert 'key' in result
|
||||
assert 'value' in result
|
||||
|
||||
def test_extract_json_string_with_content(self):
|
||||
"""JSON string with 'content' key extracts content."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('{"content": "json content"}')
|
||||
|
||||
assert result == 'json content'
|
||||
|
||||
def test_extract_json_string_without_content(self):
|
||||
"""JSON string without 'content' key returns original."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('{"other": "value"}')
|
||||
|
||||
assert '{"other": "value"}' in result
|
||||
|
||||
def test_extract_whitespace_string(self):
|
||||
"""Whitespace string returns empty."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output(' ')
|
||||
|
||||
assert result == ''
|
||||
|
||||
|
||||
class TestDifyRunnerConfigValidation:
|
||||
"""Tests for runner config validation."""
|
||||
|
||||
def test_invalid_app_type_raises(self):
|
||||
"""Invalid app-type raises DifyAPIError."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
from langbot.libs.dify_service_api.v1.errors import DifyAPIError
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'invalid-type',
|
||||
'api-key': 'test',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
with pytest.raises(DifyAPIError, match='不支持'):
|
||||
DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
|
||||
def test_valid_app_types(self):
|
||||
"""Valid app-types don't raise."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
for app_type in ['chat', 'agent', 'workflow']:
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': app_type,
|
||||
'api-key': 'test',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
# Should not raise
|
||||
assert runner is not None
|
||||
|
||||
|
||||
class TestDifyRunnerInit:
|
||||
"""Tests for runner initialization."""
|
||||
|
||||
def test_runner_stores_config(self):
|
||||
"""Runner stores pipeline_config."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'chat',
|
||||
'api-key': 'test-key',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
|
||||
assert runner.pipeline_config == pipeline_config
|
||||
assert runner.ap == mock_app
|
||||
@@ -1,281 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator
|
||||
|
||||
|
||||
class RecordingProvider:
|
||||
def __init__(self):
|
||||
self.requests: list[dict] = []
|
||||
|
||||
async def invoke_llm(self, query, model, messages, funcs, extra_args=None, remove_think=None):
|
||||
self.requests.append(
|
||||
{
|
||||
'messages': list(messages),
|
||||
'funcs': list(funcs),
|
||||
'remove_think': remove_think,
|
||||
}
|
||||
)
|
||||
|
||||
if len(self.requests) == 1:
|
||||
return provider_message.Message(
|
||||
role='assistant',
|
||||
content='Let me calculate that exactly.',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name='exec',
|
||||
arguments=json.dumps(
|
||||
{'command': ("python - <<'PY'\nnums = [1, 2, 3, 4]\nprint(sum(nums) / len(nums))\nPY")}
|
||||
),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
tool_result = json.loads(messages[-1].content)
|
||||
return provider_message.Message(
|
||||
role='assistant',
|
||||
content=f'The average is {tool_result["stdout"]}.',
|
||||
)
|
||||
|
||||
|
||||
class RecordingStreamProvider:
|
||||
def __init__(self):
|
||||
self.stream_requests: list[dict] = []
|
||||
|
||||
def invoke_llm_stream(self, query, model, messages, funcs, extra_args=None, remove_think=None):
|
||||
self.stream_requests.append(
|
||||
{
|
||||
'messages': list(messages),
|
||||
'funcs': list(funcs),
|
||||
'remove_think': remove_think,
|
||||
}
|
||||
)
|
||||
|
||||
async def _stream():
|
||||
if len(self.stream_requests) == 1:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name='exec',
|
||||
arguments=json.dumps({'command': "python -c 'print(1)'"}),
|
||||
),
|
||||
)
|
||||
],
|
||||
is_final=True,
|
||||
)
|
||||
return
|
||||
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content='Tool execution failed.',
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
def make_query() -> pipeline_query.Query:
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
|
||||
return pipeline_query.Query.model_construct(
|
||||
query_id='avg-query',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=[],
|
||||
message_event=None,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='pipeline-uuid',
|
||||
bot_uuid='bot-uuid',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'remove-think': False}},
|
||||
},
|
||||
prompt=SimpleNamespace(messages=[]),
|
||||
messages=[],
|
||||
user_message=provider_message.Message(
|
||||
role='user',
|
||||
content='Please calculate the average of 1, 2, 3, and 4.',
|
||||
),
|
||||
use_funcs=[SimpleNamespace(name='exec')],
|
||||
use_llm_model_uuid='test-model-uuid',
|
||||
variables={},
|
||||
)
|
||||
|
||||
|
||||
def test_stream_accumulator_merges_fragmented_tool_call_arguments():
|
||||
accumulator = _StreamAccumulator(msg_sequence=1)
|
||||
|
||||
assert (
|
||||
accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='{"command":'),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
emitted = accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'),
|
||||
)
|
||||
],
|
||||
is_final=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert emitted is not None
|
||||
final_msg = accumulator.final_message()
|
||||
assert final_msg.tool_calls[0].function.name == 'exec'
|
||||
assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_uses_exec_for_exact_calculation():
|
||||
provider = RecordingProvider()
|
||||
model = SimpleNamespace(
|
||||
provider=provider,
|
||||
model_entity=SimpleNamespace(
|
||||
uuid='test-model-uuid',
|
||||
name='test-model',
|
||||
abilities=['func_call'],
|
||||
extra_args={},
|
||||
),
|
||||
)
|
||||
|
||||
tool_manager = SimpleNamespace(
|
||||
execute_func_call=AsyncMock(
|
||||
return_value={
|
||||
'session_id': 'avg-query',
|
||||
'backend': 'podman',
|
||||
'status': 'completed',
|
||||
'ok': True,
|
||||
'exit_code': 0,
|
||||
'stdout': '2.5',
|
||||
'stderr': '',
|
||||
'duration_ms': 18,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
app = SimpleNamespace(
|
||||
logger=Mock(),
|
||||
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
|
||||
tool_mgr=tool_manager,
|
||||
rag_mgr=SimpleNamespace(),
|
||||
box_service=SimpleNamespace(
|
||||
get_system_guidance=Mock(
|
||||
return_value=(
|
||||
'When the exec tool is available, use it for exact calculations, statistics, '
|
||||
'structured data parsing, and code execution instead of estimating mentally. '
|
||||
'Unless the user explicitly asks for the script, code, or implementation details, '
|
||||
'do not include the generated script in the final answer. '
|
||||
'A default workspace is mounted at /workspace for file tasks.'
|
||||
)
|
||||
),
|
||||
),
|
||||
skill_mgr=SimpleNamespace(
|
||||
get_skills_for_pipeline=AsyncMock(return_value=[]),
|
||||
detect_skill_activation=AsyncMock(return_value=None),
|
||||
build_activation_prompt=Mock(return_value=None),
|
||||
),
|
||||
)
|
||||
|
||||
runner = LocalAgentRunner(app, pipeline_config={})
|
||||
query = make_query()
|
||||
|
||||
results = [message async for message in runner.run(query)]
|
||||
|
||||
assert [message.role for message in results] == ['assistant', 'tool', 'assistant']
|
||||
assert results[-1].content == 'The average is 2.5.'
|
||||
|
||||
tool_manager.execute_func_call.assert_awaited_once()
|
||||
tool_name, tool_parameters = tool_manager.execute_func_call.await_args.args[:2]
|
||||
assert tool_name == 'exec'
|
||||
assert 'print(sum(nums) / len(nums))' in tool_parameters['command']
|
||||
|
||||
first_request = provider.requests[0]
|
||||
assert any(
|
||||
message.role == 'system'
|
||||
and 'exec' in str(message.content)
|
||||
and 'exact calculations' in str(message.content)
|
||||
and 'Unless the user explicitly asks for the script' in str(message.content)
|
||||
and '/workspace' in str(message.content)
|
||||
for message in first_request['messages']
|
||||
)
|
||||
assert [tool.name for tool in first_request['funcs']] == ['exec']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_streaming_tool_error_yields_message_chunks():
|
||||
provider = RecordingStreamProvider()
|
||||
model = SimpleNamespace(
|
||||
provider=provider,
|
||||
model_entity=SimpleNamespace(
|
||||
uuid='test-model-uuid',
|
||||
name='test-model',
|
||||
abilities=['func_call'],
|
||||
extra_args={},
|
||||
),
|
||||
)
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=True)
|
||||
|
||||
query = make_query()
|
||||
query.adapter = adapter
|
||||
|
||||
app = SimpleNamespace(
|
||||
logger=Mock(),
|
||||
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
|
||||
tool_mgr=SimpleNamespace(execute_func_call=AsyncMock(side_effect=RuntimeError('boom'))),
|
||||
rag_mgr=SimpleNamespace(),
|
||||
box_service=SimpleNamespace(
|
||||
get_system_guidance=Mock(return_value='sandbox guidance'),
|
||||
),
|
||||
skill_mgr=SimpleNamespace(
|
||||
get_skills_for_pipeline=AsyncMock(return_value=[]),
|
||||
detect_skill_activation=AsyncMock(return_value=None),
|
||||
build_activation_prompt=Mock(return_value=None),
|
||||
),
|
||||
)
|
||||
|
||||
runner = LocalAgentRunner(app, pipeline_config={})
|
||||
|
||||
results = [message async for message in runner.run(query)]
|
||||
|
||||
assert all(isinstance(message, provider_message.MessageChunk) for message in results)
|
||||
assert any(message.role == 'tool' and message.content == 'err: boom' for message in results)
|
||||
@@ -12,13 +12,39 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.api.http.service.model import _runtime_model_data
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.api.http.service.provider import ModelProviderService
|
||||
from langbot.pkg.entity.persistence import model as persistence_model
|
||||
from langbot.pkg.pipeline.preproc.preproc import PreProcessor
|
||||
from langbot.pkg.provider.modelmgr import requester
|
||||
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
||||
from langbot.pkg.provider.modelmgr.token import TokenManager
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
|
||||
|
||||
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
class FakeAgentRunnerRegistry:
|
||||
async def get(self, runner_id, bound_plugins=None):
|
||||
return AgentRunnerDescriptor(
|
||||
id=runner_id,
|
||||
source='plugin',
|
||||
label={'en_US': 'Local Agent'},
|
||||
plugin_author='langbot',
|
||||
plugin_name='local-agent',
|
||||
runner_name='default',
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
{'name': 'prompt', 'type': 'prompt-editor', 'default': []},
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
|
||||
],
|
||||
capabilities={'tool_calling': True, 'knowledge_retrieval': True, 'multimodal_input': True},
|
||||
permissions={
|
||||
'models': ['invoke', 'stream'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_llm_model_data_preserves_uuid_after_update_payload_uuid_removed():
|
||||
@@ -120,6 +146,7 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
|
||||
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = Mock()
|
||||
ap.agent_runner_registry = FakeAgentRunnerRegistry()
|
||||
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
|
||||
ap.tool_mgr = SimpleNamespace(get_all_tools=AsyncMock(return_value=[]))
|
||||
ap.skill_mgr = None # PreProcessor only uses skill_mgr for the local-agent skill-binding branch
|
||||
@@ -183,11 +210,13 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
|
||||
)
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {
|
||||
'model': {'primary': model_uuid, 'fallbacks': []},
|
||||
'prompt': [],
|
||||
'knowledge-bases': [],
|
||||
'runner': {'id': DEFAULT_RUNNER_ID},
|
||||
'runner_config': {
|
||||
DEFAULT_RUNNER_ID: {
|
||||
'model': {'primary': model_uuid, 'fallbacks': []},
|
||||
'prompt': [],
|
||||
'knowledge-bases': [],
|
||||
},
|
||||
},
|
||||
},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
@@ -220,8 +249,3 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
|
||||
processed_query = result.new_query
|
||||
|
||||
assert processed_query.use_llm_model_uuid == model_uuid
|
||||
|
||||
runner = SimpleNamespace(ap=ap, pipeline_config=pipeline_config)
|
||||
candidates = await LocalAgentRunner._get_model_candidates(runner, processed_query)
|
||||
|
||||
assert [model.model_entity.uuid for model in candidates] == [model_uuid]
|
||||
|
||||
@@ -302,6 +302,59 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
|
||||
assert result.role == 'assistant'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stashes_usage(runtime_provider, runtime_llm_model):
    """invoke_llm stores the usage returned by the requester under the query's usage variable."""
    import langbot_plugin.api.entities.builtin.provider.message as provider_message
    import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query

    provider = runtime_provider

    # model_construct skips validation, so placeholder values (None / empty) are fine here.
    query_fields = dict(
        query_id='test-query-usage',
        launcher_type='person',
        launcher_id=12345,
        sender_id=12345,
        message_chain=None,
        message_event=None,
        adapter=None,
        pipeline_uuid='pipeline-uuid',
        bot_uuid='bot-uuid',
        pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
        session=None,
        prompt=None,
        messages=[],
        user_message=None,
        use_funcs=[],
        use_llm_model_uuid=None,
        variables={},
        resp_messages=[],
        resp_message_chain=None,
        current_stage_name=None,
    )
    query = pipeline_query.Query.model_construct(**query_fields)

    usage = {
        'prompt_tokens': 11,
        'completion_tokens': 7,
        'total_tokens': 18,
        'prompt_tokens_details': {'cached_tokens': 3},
    }
    # The mocked requester answers with a (message, usage) pair.
    mocked_reply = (provider_message.Message(role='assistant', content='ok'), usage)
    provider.requester.invoke_llm = AsyncMock(return_value=mocked_reply)

    outgoing = [provider_message.Message(role='user', content='Hello')]
    result = await provider.invoke_llm(query, runtime_llm_model, outgoing)

    assert result.content == 'ok'
    assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider, runtime_llm_model):
|
||||
"""Test RuntimeProvider.invoke_llm_stream yields chunks from requester."""
|
||||
@@ -345,6 +398,61 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
|
||||
assert chunks[0].role == 'assistant'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stream_stashes_usage(runtime_provider, runtime_llm_model):
    """After streaming, the usage written by the requester is readable under the query's usage variable."""
    import langbot_plugin.api.entities.builtin.provider.message as provider_message
    import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query

    provider = runtime_provider

    # model_construct skips validation, so placeholder values (None / empty) are fine here.
    query_fields = dict(
        query_id='test-stream-usage',
        launcher_type='person',
        launcher_id=12345,
        sender_id=12345,
        message_chain=None,
        message_event=None,
        adapter=None,
        pipeline_uuid='pipeline-uuid',
        bot_uuid='bot-uuid',
        pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
        session=None,
        prompt=None,
        messages=[],
        user_message=None,
        use_funcs=[],
        use_llm_model_uuid=None,
        variables={},
        resp_messages=[],
        resp_message_chain=None,
        current_stage_name=None,
    )
    query = pipeline_query.Query.model_construct(**query_fields)

    usage = {
        'prompt_tokens': 13,
        'completion_tokens': 2,
        'total_tokens': 15,
    }

    async def _stream_with_usage(**call_kwargs):
        # Record usage on the query the provider passed in, then emit a single chunk.
        call_kwargs['query'].variables[requester.LLM_USAGE_QUERY_VARIABLE] = usage
        yield provider_message.MessageChunk(role='assistant', content='ok')

    provider.requester.invoke_llm_stream = _stream_with_usage

    outgoing = [provider_message.Message(role='user', content='Hello')]
    chunks = []
    async for piece in provider.invoke_llm_stream(query, runtime_llm_model, outgoing):
        chunks.append(piece)

    assert len(chunks) == 1
    assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_embedding_returns_vectors(runtime_provider, runtime_embedding_model):
|
||||
"""Test RuntimeProvider.invoke_embedding returns embedding vectors."""
|
||||
|
||||
@@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.manifest import (
|
||||
AgentRunnerCapabilities,
|
||||
AgentRunnerPermissions,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.pipeline.query import Query
|
||||
from langbot_plugin.api.entities.builtin.platform.entities import Friend
|
||||
from langbot_plugin.api.entities.builtin.platform.events import FriendMessage
|
||||
@@ -17,6 +21,32 @@ from langbot_plugin.api.entities.builtin.provider.prompt import Prompt
|
||||
from langbot_plugin.api.entities.builtin.provider.session import Conversation, LauncherTypes, Session
|
||||
|
||||
|
||||
class _FakeRunnerDescriptor:
    """Runner descriptor double: fixed config schema, validated permissions, and capability helpers."""

    config_schema = [
        {'name': 'model', 'type': 'model-fallback-selector'},
        {'name': 'prompt', 'type': 'prompt-editor', 'default': []},
        {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
    ]

    # Validated straight from the raw mapping; no intermediate dict attribute is kept.
    permissions = AgentRunnerPermissions.model_validate(
        {
            'models': ['invoke', 'stream'],
            'tools': ['detail', 'call'],
            'knowledge_bases': ['list', 'retrieve'],
        }
    )

    capabilities = AgentRunnerCapabilities(
        tool_calling=True,
        knowledge_retrieval=True,
        multimodal_input=True,
        skill_authoring=True,
    )

    def supports_tool_calling(self):
        """Return the ``tool_calling`` capability flag."""
        flag = self.capabilities.tool_calling
        return flag

    def supports_knowledge_retrieval(self):
        """Return the ``knowledge_retrieval`` capability flag."""
        flag = self.capabilities.knowledge_retrieval
        return flag
|
||||
|
||||
|
||||
def _make_query() -> Query:
|
||||
message_chain = MessageChain([Plain(text='create a skill')])
|
||||
return Query(
|
||||
@@ -34,11 +64,13 @@ def _make_query() -> Query:
|
||||
pipeline_uuid='pipe-1',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {
|
||||
'model': {'primary': 'model-1', 'fallbacks': []},
|
||||
'prompt': 'default',
|
||||
'knowledge-bases': [],
|
||||
'runner': {'id': 'plugin:langbot/local-agent/default'},
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'model': {'primary': 'model-1', 'fallbacks': []},
|
||||
'prompt': [],
|
||||
'knowledge-bases': [],
|
||||
},
|
||||
},
|
||||
},
|
||||
'trigger': {'misc': {}},
|
||||
@@ -57,6 +89,15 @@ def _make_conversation() -> Conversation:
|
||||
)
|
||||
|
||||
|
||||
async def _passthrough_preproc_event(event, bound_plugins):
|
||||
return SimpleNamespace(
|
||||
event=SimpleNamespace(
|
||||
default_prompt=event.default_prompt,
|
||||
prompt=event.prompt,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_app(*, skill_service) -> SimpleNamespace:
|
||||
session = Session(launcher_type=LauncherTypes.PERSON, launcher_id='launcher-1', sender_id='sender-1')
|
||||
conversation = _make_conversation()
|
||||
@@ -83,8 +124,8 @@ def _make_app(*, skill_service) -> SimpleNamespace:
|
||||
pipeline_service=SimpleNamespace(
|
||||
get_pipeline=AsyncMock(return_value={'extensions_preferences': {'enable_all_skills': True}})
|
||||
),
|
||||
agent_runner_registry=SimpleNamespace(get=AsyncMock(return_value=_FakeRunnerDescriptor())),
|
||||
skill_mgr=SimpleNamespace(
|
||||
build_skill_aware_prompt_addition=Mock(return_value=''),
|
||||
skills={},
|
||||
),
|
||||
skill_service=skill_service,
|
||||
@@ -121,6 +162,28 @@ async def test_preproc_enables_skill_authoring_tools_when_skill_service_availabl
|
||||
app.tool_mgr.get_all_tools.assert_awaited_once_with(None, None, include_skill_authoring=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_preproc_puts_host_skill_tools_into_query_scope():
    """Tools returned by the tool manager during preproc end up, in order, in ``query.use_funcs``."""
    preproc_module, entities_module = _import_preproc_modules()

    app = _make_app(skill_service=SimpleNamespace())
    host_tools = [SimpleNamespace(name=tool_name) for tool_name in ('activate', 'register_skill')]
    app.tool_mgr.get_all_tools = AsyncMock(return_value=host_tools)

    query = _make_query()
    stage = preproc_module.PreProcessor(app)
    outcome = await stage.process(query, 'PreProcessor')

    assert outcome.result_type == entities_module.ResultType.CONTINUE
    app.tool_mgr.get_all_tools.assert_awaited_once_with(None, None, include_skill_authoring=True)
    exposed_names = [tool.name for tool in query.use_funcs]
    assert exposed_names == ['activate', 'register_skill']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_disables_skill_authoring_tools_when_skill_service_missing():
|
||||
preproc_module, entities_module = _import_preproc_modules()
|
||||
@@ -135,30 +198,24 @@ async def test_preproc_disables_skill_authoring_tools_when_skill_service_missing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_injects_skill_index_into_system_prompt():
|
||||
"""The Tool Call activation pattern still needs the LLM to know which
|
||||
skills exist. PreProcessor must append the SkillManager's index
|
||||
addendum to the first system message."""
|
||||
async def test_preproc_records_all_visible_skills_without_prompt_injection():
|
||||
preproc_module, entities_module = _import_preproc_modules()
|
||||
|
||||
app = _make_app(skill_service=SimpleNamespace())
|
||||
addendum = '\n\nAvailable Skills:\n- demo (demo): Demo skill.\n\nCall activate ...'
|
||||
app.skill_mgr.build_skill_aware_prompt_addition = Mock(return_value=addendum)
|
||||
|
||||
query = _make_query()
|
||||
result = await stage_process_capture(preproc_module, app, query)
|
||||
|
||||
assert result.result_type == entities_module.ResultType.CONTINUE
|
||||
app.skill_mgr.build_skill_aware_prompt_addition.assert_called_once_with(bound_skills=None)
|
||||
app.pipeline_service.get_pipeline.assert_awaited_once_with('pipe-1')
|
||||
assert query.variables.get('_pipeline_bound_skills') is None
|
||||
head = query.prompt.messages[0]
|
||||
assert head.role == 'system'
|
||||
assert head.content.endswith(addendum)
|
||||
assert head.content == 'system prompt'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_respects_pipeline_bound_skills_subset():
|
||||
"""When ``enable_all_skills`` is false the bound list is passed through
|
||||
so the addendum only mentions skills allowed for this pipeline."""
|
||||
preproc_module, entities_module = _import_preproc_modules()
|
||||
|
||||
app = _make_app(skill_service=SimpleNamespace())
|
||||
@@ -170,31 +227,78 @@ async def test_preproc_respects_pipeline_bound_skills_subset():
|
||||
}
|
||||
}
|
||||
)
|
||||
app.skill_mgr.build_skill_aware_prompt_addition = Mock(return_value='')
|
||||
|
||||
query = _make_query()
|
||||
result = await stage_process_capture(preproc_module, app, query)
|
||||
|
||||
assert result.result_type == entities_module.ResultType.CONTINUE
|
||||
app.skill_mgr.build_skill_aware_prompt_addition.assert_called_once_with(bound_skills=['only-this'])
|
||||
assert query.variables.get('_pipeline_bound_skills') == ['only-this']
|
||||
assert query.prompt.messages[0].content == 'system prompt'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preproc_skips_injection_when_addendum_is_empty():
|
||||
"""No visible skills → system prompt is left untouched (no
|
||||
``Available Skills`` block appended)."""
|
||||
async def test_preproc_does_not_load_skill_preferences_without_skill_authoring_service():
|
||||
preproc_module, entities_module = _import_preproc_modules()
|
||||
|
||||
app = _make_app(skill_service=SimpleNamespace())
|
||||
app.skill_mgr.build_skill_aware_prompt_addition = Mock(return_value='')
|
||||
app = _make_app(skill_service=None)
|
||||
|
||||
query = _make_query()
|
||||
result = await stage_process_capture(preproc_module, app, query)
|
||||
|
||||
assert result.result_type == entities_module.ResultType.CONTINUE
|
||||
if query.prompt and query.prompt.messages:
|
||||
assert 'Available Skills' not in (query.prompt.messages[0].content or '')
|
||||
app.pipeline_service.get_pipeline.assert_not_awaited()
|
||||
assert '_pipeline_bound_skills' not in query.variables
|
||||
assert query.prompt.messages[0].content == 'system prompt'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_preproc_uses_transcript_history_view_when_available():
    """When the history loader returns messages, they replace the conversation's own messages on the query."""
    preproc_module, entities_module = _import_preproc_modules()

    app = _make_app(skill_service=SimpleNamespace())
    # Seed the conversation with different content so a fallback would be detectable.
    stored_conversation = app.sess_mgr.get_conversation.return_value
    stored_conversation.messages = [Message(role='user', content='legacy history')]
    app.plugin_connector.emit_event = AsyncMock(side_effect=_passthrough_preproc_event)

    transcript_messages = [
        Message(role='user', content='from transcript user'),
        Message(role='assistant', content='from transcript assistant'),
    ]

    stage = preproc_module.PreProcessor(app)
    history_loader = AsyncMock(return_value=transcript_messages)
    stage._load_agent_runner_history_messages = history_loader

    query = _make_query()
    outcome = await stage.process(query, 'PreProcessor')

    assert outcome.result_type == entities_module.ResultType.CONTINUE
    assert query.messages == transcript_messages
    history_loader.assert_awaited_once_with(
        'plugin:langbot/local-agent/default',
        'conv-1',
        bot_id='bot-1',
        workspace_id=None,
        thread_id=None,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_preproc_falls_back_to_conversation_messages_when_transcript_empty():
    """When the history loader returns ``None``, the query carries the conversation's own messages."""
    preproc_module, entities_module = _import_preproc_modules()

    legacy_messages = [Message(role='user', content='legacy history')]
    app = _make_app(skill_service=SimpleNamespace())
    app.sess_mgr.get_conversation.return_value.messages = legacy_messages
    app.plugin_connector.emit_event = AsyncMock(side_effect=_passthrough_preproc_event)

    stage = preproc_module.PreProcessor(app)
    # The loader reports no transcript history.
    stage._load_agent_runner_history_messages = AsyncMock(return_value=None)

    query = _make_query()
    outcome = await stage.process(query, 'PreProcessor')

    assert outcome.result_type == entities_module.ResultType.CONTINUE
    assert query.messages == legacy_messages
|
||||
|
||||
|
||||
async def stage_process_capture(preproc_module, app, query):
|
||||
|
||||
@@ -143,10 +143,6 @@ def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
|
||||
# Mock core.app - Application class is referenced but not instantiated
|
||||
mock_app = MagicMock()
|
||||
|
||||
# Mock provider.runner - has preregistered_runners attribute
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.preregistered_runners = [] # Empty by default, tests override
|
||||
|
||||
# Mock utils.importutil - prevents auto-import of runners
|
||||
mock_importutil = MagicMock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
@@ -158,19 +154,11 @@ def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
|
||||
'langbot.pkg.pipeline.controller': MagicMock(),
|
||||
'langbot.pkg.pipeline.pipelinemgr': MagicMock(),
|
||||
'langbot.pkg.pipeline.process.process': MagicMock(),
|
||||
'langbot.pkg.provider.runner': mock_runner,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
}
|
||||
|
||||
|
||||
# Package attributes that need to be updated alongside sys.modules mocking.
|
||||
# When Python imports a submodule (e.g., langbot.pkg.provider.runner), it
|
||||
# automatically sets an attribute on the parent package. The import statement
|
||||
# `from ....provider import runner` gets this attribute, not sys.modules directly.
|
||||
# This dict maps mock module names to the parent packages that need attribute updates.
|
||||
_PACKAGE_ATTRIBUTE_UPDATES: dict[str, tuple[str, str]] = {
|
||||
'langbot.pkg.provider.runner': ('langbot.pkg.provider', 'runner'),
|
||||
}
|
||||
_PACKAGE_ATTRIBUTE_UPDATES: dict[str, tuple[str, str]] = {}
|
||||
|
||||
|
||||
def get_handler_modules_to_clear(handler_name: str) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user