mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
refactor(test): consolidate FakeApp and add sys.modules isolation utility
- Extract tests/utils/import_isolation.py with isolated_sys_modules context manager - Extend tests/factories/app.py FakeApp with handler-specific attributes - Refactor test_chat_handler.py to use centralized FakeApp and cached imports - Refactor test_command_handler.py with mock_execute_factory fixture - Refactor test_smoke.py to move import-time sys.modules manipulation into fixture - Add SQLite migration integration tests (G-002) - Add HTTP API smoke integration tests (G-005) - Update CI workflow to call pytest for SQLite migrations (G-004) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
56
.github/workflows/test-migrations.yml
vendored
56
.github/workflows/test-migrations.yml
vendored
@@ -9,11 +9,13 @@ on:
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- 'src/langbot/pkg/persistence/**'
|
||||
- 'src/langbot/pkg/entity/persistence/**'
|
||||
- 'tests/integration/persistence/**'
|
||||
|
||||
jobs:
|
||||
test-migrations-sqlite:
|
||||
@@ -34,53 +36,13 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Test Alembic upgrade (SQLite)
|
||||
run: |
|
||||
uv run python -c "
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
|
||||
|
||||
async def main():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
|
||||
|
||||
# Create all tables (simulates existing DB)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(engine, '0001_baseline')
|
||||
rev = await get_alembic_current(engine)
|
||||
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
|
||||
print(f'Stamped: {rev}')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev = await get_alembic_current(engine)
|
||||
print(f'After upgrade: {rev}')
|
||||
assert rev is not None, 'Expected a revision after upgrade'
|
||||
|
||||
# Verify idempotent
|
||||
await run_alembic_upgrade(engine, 'head')
|
||||
rev2 = await get_alembic_current(engine)
|
||||
assert rev2 == rev, f'Expected {rev}, got {rev2}'
|
||||
print(f'Idempotent check passed: {rev2}')
|
||||
|
||||
# Fresh DB: upgrade from scratch
|
||||
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
|
||||
async with engine2.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await run_alembic_upgrade(engine2, 'head')
|
||||
rev3 = await get_alembic_current(engine2)
|
||||
print(f'Fresh DB upgrade: {rev3}')
|
||||
assert rev3 is not None
|
||||
|
||||
print('All SQLite migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
- name: Run SQLite migration tests
|
||||
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
|
||||
# TODO(G-003): Migrate PostgreSQL tests to pytest integration tests
|
||||
# PostgreSQL requires external database service, which will be handled in G-003.
|
||||
# The inline script below will be replaced with:
|
||||
# uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
|
||||
test-migrations-postgres:
|
||||
name: Migrations (PostgreSQL)
|
||||
runs-on: ubuntu-latest
|
||||
@@ -168,4 +130,4 @@ jobs:
|
||||
print('All PostgreSQL migration tests passed!')
|
||||
|
||||
asyncio.run(main())
|
||||
"
|
||||
"
|
||||
@@ -17,6 +17,14 @@ tests/
|
||||
│ ├── message.py # Message/query factories
|
||||
│ ├── provider.py # FakeProvider factory
|
||||
│ └── platform.py # FakePlatform factory
|
||||
├── integration/ # Integration tests (real resources)
|
||||
│ ├── __init__.py
|
||||
│ ├── api/ # HTTP API tests
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── test_smoke.py # API smoke tests
|
||||
│ └── persistence/ # Database/persistence tests
|
||||
│ ├── __init__.py
|
||||
│ └── test_migrations.py # Alembic migration tests
|
||||
├── smoke/ # Smoke tests (quick validation)
|
||||
│ └── test_fake_message_flow.py
|
||||
├── unit_tests/ # Unit tests
|
||||
@@ -28,6 +36,9 @@ tests/
|
||||
│ ├── plugin/ # Plugin system tests
|
||||
│ ├── provider/ # Provider tests
|
||||
│ └── storage/ # Storage tests
|
||||
├── utils/ # Test utilities
|
||||
│ ├── __init__.py
|
||||
│ └── import_isolation.py # sys.modules isolation for circular imports
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
@@ -147,13 +158,51 @@ uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_
|
||||
# Run only unit tests
|
||||
uv run pytest tests/unit_tests/ -m unit
|
||||
|
||||
# Run only integration tests (when available)
|
||||
uv run pytest tests/ -m integration
|
||||
# Run only integration tests
|
||||
uv run pytest tests/integration/ -m integration
|
||||
|
||||
# Run integration tests excluding slow ones
|
||||
uv run pytest tests/integration/ -m "not slow" -q
|
||||
|
||||
# Skip slow tests
|
||||
uv run pytest tests/unit_tests/ -m "not slow"
|
||||
```
|
||||
|
||||
### Running integration tests
|
||||
|
||||
Integration tests validate real system behavior with actual database/network resources.
|
||||
|
||||
```bash
|
||||
# Run all integration tests (excluding slow ones)
|
||||
uv run pytest tests/integration/ -m "not slow" -q
|
||||
|
||||
# Run SQLite migration integration tests
|
||||
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
|
||||
# Run API smoke integration tests
|
||||
uv run pytest tests/integration/api/test_smoke.py -q
|
||||
|
||||
# Run with verbose output
|
||||
uv run pytest tests/integration/ -v
|
||||
```
|
||||
|
||||
Note: Integration tests use:
|
||||
- Temporary databases (tmp_path) for persistence tests
|
||||
- Fake app/services for API tests (no real provider/platform)
|
||||
- Do not require external services
|
||||
|
||||
### Running migration tests locally
|
||||
|
||||
SQLite migration tests can be run locally without any external dependencies:
|
||||
|
||||
```bash
|
||||
# SQLite migration tests (uses tmp_path, no external DB needed)
|
||||
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
|
||||
```
|
||||
|
||||
CI workflow `.github/workflows/test-migrations.yml` runs SQLite tests using pytest.
|
||||
PostgreSQL migration tests still use inline Python script (will be migrated to pytest in G-003).
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||
@@ -250,7 +299,10 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [x] Add integration tests for database migrations (SQLite)
|
||||
- [ ] Add PostgreSQL migration integration tests (G-003)
|
||||
- [ ] Add integration tests for full pipeline execution
|
||||
- [x] Add API smoke integration tests
|
||||
- [ ] Add E2E tests
|
||||
- [ ] Add performance benchmarks
|
||||
- [ ] Add mutation testing for better coverage quality
|
||||
|
||||
@@ -18,6 +18,8 @@ class FakeApp:
|
||||
command_prefix: list[str] = ["/", "!"],
|
||||
command_enable: bool = True,
|
||||
pipeline_concurrency: int = 10,
|
||||
admins: list[str] | None = None,
|
||||
**extra_attrs,
|
||||
):
|
||||
self.logger = self._create_mock_logger()
|
||||
self.sess_mgr = self._create_mock_session_manager()
|
||||
@@ -30,9 +32,19 @@ class FakeApp:
|
||||
command_prefix=command_prefix,
|
||||
command_enable=command_enable,
|
||||
pipeline_concurrency=pipeline_concurrency,
|
||||
admins=admins or [],
|
||||
)
|
||||
self.task_mgr = self._create_mock_task_manager()
|
||||
|
||||
# Handler-specific optional attributes
|
||||
self.telemetry = self._create_mock_telemetry()
|
||||
self.survey = None
|
||||
self.cmd_mgr = self._create_mock_cmd_mgr()
|
||||
|
||||
# Apply any extra attributes for specific test scenarios
|
||||
for name, value in extra_attrs.items():
|
||||
setattr(self, name, value)
|
||||
|
||||
# Captured outbound messages (for assertions)
|
||||
self._outbound_messages: list = []
|
||||
|
||||
@@ -82,11 +94,13 @@ class FakeApp:
|
||||
command_prefix: list[str],
|
||||
command_enable: bool,
|
||||
pipeline_concurrency: int,
|
||||
admins: list[str],
|
||||
):
|
||||
instance_config = Mock()
|
||||
instance_config.data = {
|
||||
"command": {"prefix": command_prefix, "enable": command_enable},
|
||||
"concurrency": {"pipeline": pipeline_concurrency},
|
||||
"admins": admins,
|
||||
}
|
||||
return instance_config
|
||||
|
||||
@@ -95,6 +109,16 @@ class FakeApp:
|
||||
task_mgr.create_task = Mock()
|
||||
return task_mgr
|
||||
|
||||
def _create_mock_telemetry(self):
|
||||
telemetry = AsyncMock()
|
||||
telemetry.start_send_task = AsyncMock()
|
||||
return telemetry
|
||||
|
||||
def _create_mock_cmd_mgr(self):
|
||||
cmd_mgr = AsyncMock()
|
||||
cmd_mgr.execute = AsyncMock()
|
||||
return cmd_mgr
|
||||
|
||||
def capture_message(self, message):
|
||||
"""Capture an outbound message for test assertions."""
|
||||
self._outbound_messages.append(message)
|
||||
|
||||
6
tests/integration/__init__.py
Normal file
6
tests/integration/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Integration tests package.
|
||||
|
||||
These tests validate real system behavior with actual database/network resources.
|
||||
Run with: uv run pytest tests/integration/ -m "not slow" -q
|
||||
"""
|
||||
5
tests/integration/api/__init__.py
Normal file
5
tests/integration/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
API integration tests package.
|
||||
|
||||
Tests for HTTP API endpoints using Quart test client.
|
||||
"""
|
||||
347
tests/integration/api/test_smoke.py
Normal file
347
tests/integration/api/test_smoke.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
API smoke integration tests.
|
||||
|
||||
Tests real HTTP API behavior using Quart test client.
|
||||
Validates controller/service/routing wiring without real provider/platform.
|
||||
|
||||
Run: uv run pytest tests/integration/api/test_smoke.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, Mock
|
||||
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain for API controller using isolated_sys_modules.
|
||||
|
||||
Chain: http_controller → groups/plugins → core.app → pipeline entities
|
||||
|
||||
We need to mock core.app to prevent the circular chain when importing HTTPController.
|
||||
But we must allow groups to be imported to populate preregistered_groups.
|
||||
"""
|
||||
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
|
||||
|
||||
# Mock core.app with minimal Application that groups can reference
|
||||
class FakeMinimalApplication:
|
||||
pass
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.Application = FakeMinimalApplication
|
||||
|
||||
# Mock core.entities with proper Enum
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
# Modules to clear (force re-import after mocking)
|
||||
clear = [
|
||||
'langbot.pkg.api.http.controller.group',
|
||||
'langbot.pkg.api.http.controller.groups',
|
||||
'langbot.pkg.api.http.controller.groups.system',
|
||||
'langbot.pkg.api.http.controller.groups.user',
|
||||
'langbot.pkg.api.http.controller.main',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
},
|
||||
clear=clear,
|
||||
):
|
||||
# Import groups after mocking core.app/core.entities
|
||||
import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# ============== FAKE APPLICATION FOR API TESTS ==============
|
||||
|
||||
@pytest.fixture
|
||||
def fake_api_app():
|
||||
"""
|
||||
Create minimal FakeApp for API smoke tests with all required services.
|
||||
|
||||
Uses tests.factories.FakeApp as base and adds API-specific services.
|
||||
"""
|
||||
app = FakeApp()
|
||||
|
||||
# API-specific config
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'plugin': {'enable_marketplace': True},
|
||||
'space': {'url': 'https://space.langbot.app'},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
|
||||
# API-specific services
|
||||
app.user_service = Mock()
|
||||
app.user_service.is_initialized = AsyncMock(return_value=False)
|
||||
app.user_service.authenticate = AsyncMock(return_value='fake_token')
|
||||
app.user_service.create_user = AsyncMock()
|
||||
app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token'))
|
||||
app.user_service.get_user_by_email = AsyncMock(return_value=Mock())
|
||||
app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token')
|
||||
|
||||
app.apikey_service = Mock()
|
||||
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
|
||||
|
||||
app.maintenance_service = Mock()
|
||||
app.maintenance_service.get_storage_analysis = AsyncMock(return_value={})
|
||||
|
||||
app.plugin_connector.is_enable_plugin = False
|
||||
app.plugin_connector.ping_plugin_runtime = AsyncMock()
|
||||
|
||||
app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []})
|
||||
app.task_mgr.get_task_by_id = Mock(return_value=None)
|
||||
|
||||
# Required by controller groups
|
||||
app.model_mgr = Mock()
|
||||
app.platform_mgr = Mock()
|
||||
app.pipeline_pool = Mock()
|
||||
app.pipeline_mgr = Mock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ============== QUART TEST CLIENT FIXTURE ==============
|
||||
|
||||
@pytest.fixture
|
||||
async def quart_test_client(fake_api_app):
|
||||
"""
|
||||
Create Quart test client with real HTTPController and route registration.
|
||||
|
||||
Requires mock_circular_import_chain fixture to run first (usefixtures).
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_api_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
# ============== API SMOKE TESTS ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for /healthz endpoint - simplest smoke test."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz_returns_ok(self, quart_test_client):
|
||||
"""
|
||||
/healthz endpoint returns {'code': 0, 'msg': 'ok'}.
|
||||
|
||||
This tests:
|
||||
- HTTPController instantiation
|
||||
- Quart app creation
|
||||
- Route registration
|
||||
- Basic response handling
|
||||
"""
|
||||
response = await quart_test_client.get('/healthz')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data == {'code': 0, 'msg': 'ok'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz_no_auth_required(self, quart_test_client):
|
||||
"""
|
||||
/healthz doesn't require authentication.
|
||||
|
||||
Tests that AuthType.NONE endpoints work without headers.
|
||||
"""
|
||||
response = await quart_test_client.get('/healthz')
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestSystemEndpoint:
|
||||
"""Tests for /api/v1/system endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_info_no_auth(self, quart_test_client):
|
||||
"""
|
||||
/api/v1/system/info returns system information without auth.
|
||||
|
||||
AuthType.NONE endpoint.
|
||||
"""
|
||||
response = await quart_test_client.get('/api/v1/system/info')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify response structure
|
||||
assert data['code'] == 0
|
||||
assert data['msg'] == 'ok'
|
||||
assert 'data' in data
|
||||
|
||||
# Verify expected fields
|
||||
system_data = data['data']
|
||||
assert 'version' in system_data
|
||||
assert 'debug' in system_data
|
||||
assert 'edition' in system_data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestProtectedEndpoints:
|
||||
"""Tests for authentication/authorization behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protected_endpoint_rejects_no_token(self, quart_test_client):
|
||||
"""
|
||||
Protected endpoint (USER_TOKEN) returns 401 without auth.
|
||||
|
||||
Tests that AuthType.USER_TOKEN properly rejects unauthorized requests.
|
||||
"""
|
||||
# /api/v1/user/check-token requires USER_TOKEN
|
||||
response = await quart_test_client.get('/api/v1/user/check-token')
|
||||
|
||||
assert response.status_code == 401
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify error response structure
|
||||
assert data['code'] == -1
|
||||
assert 'msg' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protected_endpoint_with_invalid_token(self, quart_test_client):
|
||||
"""
|
||||
Protected endpoint returns 401 with invalid token.
|
||||
"""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/user/check-token',
|
||||
headers={'Authorization': 'Bearer invalid_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestInvalidPayload:
|
||||
"""Tests for error handling with invalid payloads."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_json_body(self, quart_test_client):
|
||||
"""
|
||||
POST endpoint without JSON body handles gracefully.
|
||||
"""
|
||||
# /api/v1/user/auth expects JSON with 'user' and 'password'
|
||||
response = await quart_test_client.post('/api/v1/user/auth')
|
||||
|
||||
# Should return error (500, 400, or 401) with stable JSON structure
|
||||
assert response.status_code in (400, 500, 401)
|
||||
data = await response.get_json()
|
||||
|
||||
# Verify error response has expected structure
|
||||
assert 'code' in data
|
||||
assert 'msg' in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_structure(self, quart_test_client):
|
||||
"""
|
||||
POST with wrong JSON structure returns stable error.
|
||||
"""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/user/auth',
|
||||
json={'wrong_field': 'value'}
|
||||
)
|
||||
|
||||
# Should return error with stable JSON structure
|
||||
assert response.status_code in (400, 500, 401)
|
||||
data = await response.get_json()
|
||||
assert 'code' in data
|
||||
assert 'msg' in data
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestUserInitEndpoint:
|
||||
"""Tests for /api/v1/user/init endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_init_get_returns_not_initialized(self, quart_test_client):
|
||||
"""
|
||||
GET /api/v1/user/init returns initialized status.
|
||||
|
||||
Uses fake user_service.is_initialized() = False.
|
||||
"""
|
||||
response = await quart_test_client.get('/api/v1/user/init')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
|
||||
assert data['code'] == 0
|
||||
assert data['msg'] == 'ok'
|
||||
assert data['data']['initialized'] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestRealImports:
|
||||
"""Tests that verify real production code is imported."""
|
||||
|
||||
def test_http_controller_real_import(self):
|
||||
"""
|
||||
Verify HTTPController is real production class, not mock.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
assert HTTPController.__name__ == 'HTTPController'
|
||||
assert hasattr(HTTPController, 'initialize')
|
||||
assert hasattr(HTTPController, 'register_routes')
|
||||
|
||||
def test_group_real_import(self):
|
||||
"""
|
||||
Verify RouterGroup and AuthType are real production classes.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups
|
||||
|
||||
assert RouterGroup.__name__ == 'RouterGroup'
|
||||
assert hasattr(AuthType, 'NONE')
|
||||
assert hasattr(AuthType, 'USER_TOKEN')
|
||||
assert isinstance(preregistered_groups, list)
|
||||
|
||||
def test_system_group_registered(self):
|
||||
"""
|
||||
Verify SystemRouterGroup is registered in preregistered_groups.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import preregistered_groups
|
||||
|
||||
# Find system group
|
||||
system_group = None
|
||||
for g in preregistered_groups:
|
||||
if g.name == 'system':
|
||||
system_group = g
|
||||
break
|
||||
|
||||
assert system_group is not None
|
||||
assert system_group.path == '/api/v1/system'
|
||||
|
||||
def test_user_group_registered(self):
|
||||
"""
|
||||
Verify UserRouterGroup is registered in preregistered_groups.
|
||||
"""
|
||||
from langbot.pkg.api.http.controller.group import preregistered_groups
|
||||
|
||||
# Find user group
|
||||
user_group = None
|
||||
for g in preregistered_groups:
|
||||
if g.name == 'user':
|
||||
user_group = g
|
||||
break
|
||||
|
||||
assert user_group is not None
|
||||
assert user_group.path == '/api/v1/user'
|
||||
5
tests/integration/persistence/__init__.py
Normal file
5
tests/integration/persistence/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Persistence integration tests package.
|
||||
|
||||
Tests for database migrations and storage behavior.
|
||||
"""
|
||||
223
tests/integration/persistence/test_migrations.py
Normal file
223
tests/integration/persistence/test_migrations.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
SQLite migration integration tests.
|
||||
|
||||
Tests real Alembic migration behavior using temporary SQLite databases.
|
||||
Validates the migration workflow from .github/workflows/test-migrations.yml.
|
||||
|
||||
Run: uv run pytest tests/integration/persistence/test_migrations.py -q
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.persistence.alembic_runner import (
|
||||
run_alembic_upgrade,
|
||||
run_alembic_stamp,
|
||||
get_alembic_current,
|
||||
)
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_db_url(tmp_path):
|
||||
"""Create SQLite URL with temporary database file."""
|
||||
db_file = tmp_path / "test_migrations.db"
|
||||
return f"sqlite+aiosqlite:///{db_file}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_engine(sqlite_db_url):
|
||||
"""Create async SQLite engine."""
|
||||
engine = create_async_engine(sqlite_db_url)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestSQLiteMigrationBaseline:
|
||||
"""Tests for baseline stamp workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_stamp_sets_revision(self, sqlite_engine):
|
||||
"""
|
||||
Stamp baseline on existing tables sets correct revision.
|
||||
|
||||
Workflow:
|
||||
1. Create tables via Base.metadata.create_all
|
||||
2. Stamp with '0001_baseline'
|
||||
3. Verify current revision is '0001_baseline'
|
||||
"""
|
||||
# Create all tables (simulates existing DB created by ORM)
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_stamp_on_empty_db(self, sqlite_engine):
|
||||
"""
|
||||
Stamp on empty database (no tables) still sets revision.
|
||||
|
||||
This is an edge case - stamping without tables.
|
||||
"""
|
||||
# Don't create tables - stamp directly
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline'
|
||||
|
||||
|
||||
class TestSQLiteMigrationUpgrade:
|
||||
"""Tests for upgrade to head workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_from_baseline_to_head(self, sqlite_engine):
|
||||
"""
|
||||
Upgrade from baseline to head applies all migrations.
|
||||
|
||||
Workflow:
|
||||
1. Create tables
|
||||
2. Stamp baseline
|
||||
3. Upgrade to head
|
||||
4. Verify current revision is head
|
||||
"""
|
||||
# Create tables
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp baseline
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
# Upgrade to head
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration
|
||||
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_idempotent(self, sqlite_engine):
|
||||
"""
|
||||
Running upgrade to head multiple times is idempotent.
|
||||
|
||||
Workflow:
|
||||
1. Upgrade to head
|
||||
2. Get revision
|
||||
3. Upgrade to head again
|
||||
4. Verify same revision
|
||||
"""
|
||||
# Create tables
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Stamp and upgrade
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
rev1 = await get_alembic_current(sqlite_engine)
|
||||
|
||||
# Upgrade again - should be idempotent
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
rev2 = await get_alembic_current(sqlite_engine)
|
||||
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||
|
||||
|
||||
class TestSQLiteMigrationFreshDatabase:
|
||||
"""Tests for fresh database workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_db_upgrade_from_scratch(self, tmp_path):
|
||||
"""
|
||||
Fresh database (no tables) can be upgraded directly to head.
|
||||
|
||||
Workflow:
|
||||
1. Create fresh engine with new DB file
|
||||
2. Create tables
|
||||
3. Upgrade to head
|
||||
4. Verify revision
|
||||
"""
|
||||
# Use different DB file for fresh test
|
||||
fresh_db_file = tmp_path / "test_migrations_fresh.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Create tables on fresh DB
|
||||
async with fresh_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Upgrade to head directly (no baseline stamp)
|
||||
await run_alembic_upgrade(fresh_engine, 'head')
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(fresh_engine)
|
||||
assert rev is not None, "Expected a revision on fresh DB"
|
||||
|
||||
await fresh_engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_db_without_create_all_fails_gracefully(self, tmp_path):
|
||||
"""
|
||||
Fresh database without create_all may fail or have empty tables.
|
||||
|
||||
This tests the edge case where migrations run on truly empty DB.
|
||||
The behavior depends on migration script implementation.
|
||||
"""
|
||||
fresh_db_file = tmp_path / "test_empty_migrations.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Don't create tables - try upgrade directly
|
||||
# This may fail if migrations expect tables to exist
|
||||
try:
|
||||
await run_alembic_upgrade(fresh_engine, 'head')
|
||||
rev = await get_alembic_current(fresh_engine)
|
||||
# If it succeeds, verify revision
|
||||
assert rev is not None
|
||||
except Exception:
|
||||
# If it fails, that's acceptable behavior
|
||||
# Migrations may require create_all first
|
||||
pass
|
||||
|
||||
await fresh_engine.dispose()
|
||||
|
||||
|
||||
class TestSQLiteMigrationGetCurrent:
|
||||
"""Tests for get_alembic_current behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine):
|
||||
"""
|
||||
get_alembic_current returns None for unstamped database.
|
||||
"""
|
||||
# Create tables but don't stamp
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# No stamp - should return None
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
|
||||
"""
|
||||
get_alembic_current returns correct revision after stamp.
|
||||
"""
|
||||
async with sqlite_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline'
|
||||
@@ -1,428 +1,436 @@
|
||||
"""
|
||||
Unit tests for ChatMessageHandler behavior patterns.
|
||||
Unit tests for ChatMessageHandler - REAL imports.
|
||||
|
||||
Tests cover chat processing patterns:
|
||||
- Event emission for normal messages
|
||||
- Provider invocation pattern
|
||||
- Streaming response handling
|
||||
- Error handling
|
||||
|
||||
Uses pattern-based testing to avoid circular import issues.
|
||||
Tests the actual ChatMessageHandler class from production code.
|
||||
Uses tests.utils.import_isolation to break circular import chain safely.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from tests.factories import text_query
|
||||
from tests.factories import FakeApp
|
||||
|
||||
|
||||
class TestNormalMessageEventPattern:
|
||||
"""Tests for normal message event emission."""
|
||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||
|
||||
def test_person_event_type(self):
|
||||
"""Person messages use PersonNormalMessageReceived."""
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain using isolated_sys_modules.
|
||||
|
||||
launcher_type = LauncherTypes.PERSON
|
||||
Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr
|
||||
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if launcher_type == LauncherTypes.PERSON
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation.
|
||||
"""
|
||||
from tests.utils.import_isolation import (
|
||||
isolated_sys_modules,
|
||||
make_pipeline_handler_import_mocks,
|
||||
get_handler_modules_to_clear,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
assert event_class == events.PersonNormalMessageReceived
|
||||
mocks = make_pipeline_handler_import_mocks()
|
||||
|
||||
def test_group_event_type(self):
|
||||
"""Group messages use GroupNormalMessageReceived."""
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
# Create a default runner that yields a simple response
|
||||
class DefaultRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='fake response')
|
||||
|
||||
launcher_type = LauncherTypes.GROUP
|
||||
mocks['langbot.pkg.provider.runner'].preregistered_runners = [DefaultRunner]
|
||||
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if launcher_type == LauncherTypes.PERSON
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
clear = get_handler_modules_to_clear('chat')
|
||||
|
||||
assert event_class == events.GroupNormalMessageReceived
|
||||
|
||||
def test_event_fields_pattern(self):
|
||||
"""Normal message event has expected fields."""
|
||||
launcher_type = 'person'
|
||||
launcher_id = '12345'
|
||||
sender_id = '12345'
|
||||
text_message = 'hello world'
|
||||
|
||||
event_data = {
|
||||
'launcher_type': launcher_type,
|
||||
'launcher_id': launcher_id,
|
||||
'sender_id': sender_id,
|
||||
'text_message': text_message,
|
||||
}
|
||||
|
||||
assert event_data['text_message'] == 'hello world'
|
||||
with isolated_sys_modules(mocks=mocks, clear=clear):
|
||||
yield
|
||||
|
||||
|
||||
class TestPreventDefaultHandling:
|
||||
"""Tests for prevent_default handling in chat."""
|
||||
@pytest.fixture
|
||||
def fake_app():
|
||||
"""Create FakeApp instance."""
|
||||
return FakeApp()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_ctx():
|
||||
"""Create mock event context."""
|
||||
ctx = Mock()
|
||||
ctx.is_prevented_default = Mock(return_value=False)
|
||||
ctx.event = Mock()
|
||||
ctx.event.user_message_alter = None
|
||||
ctx.event.reply_message_chain = None
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_runner():
|
||||
"""Factory fixture to set a custom runner for tests."""
|
||||
def _set_runner(runner_class):
|
||||
import sys
|
||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
|
||||
return _set_runner
|
||||
|
||||
|
||||
# ============== CACHED LAZY IMPORTS ==============
|
||||
|
||||
_chat_handler_module = None
|
||||
_entities_module = None
|
||||
|
||||
|
||||
def get_chat_handler():
|
||||
"""Import ChatMessageHandler after circular import chain is mocked."""
|
||||
global _chat_handler_module
|
||||
if _chat_handler_module is None:
|
||||
from importlib import import_module
|
||||
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
|
||||
return _chat_handler_module
|
||||
|
||||
|
||||
def get_entities():
|
||||
"""Import pipeline entities - uses real module."""
|
||||
global _entities_module
|
||||
if _entities_module is None:
|
||||
from importlib import import_module
|
||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||
return _entities_module
|
||||
|
||||
|
||||
# ============== REAL ChatMessageHandler Tests ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestChatMessageHandlerReal:
|
||||
"""Tests for real ChatMessageHandler class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_interrupts(self):
|
||||
"""prevent_default without reply interrupts pipeline."""
|
||||
async def test_real_import_works(self):
|
||||
"""Verify we can import the real handler class."""
|
||||
chat = get_chat_handler()
|
||||
assert hasattr(chat, 'ChatMessageHandler')
|
||||
handler_cls = chat.ChatMessageHandler
|
||||
assert handler_cls.__name__ == 'ChatMessageHandler'
|
||||
|
||||
# Simulate event context
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = True
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.reply_message_chain = None
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_creation(self, fake_app):
|
||||
"""ChatMessageHandler can be instantiated."""
|
||||
chat = get_chat_handler()
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
assert handler.ap is fake_app
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx):
|
||||
"""prevent_default without reply chain yields INTERRUPT."""
|
||||
from tests.factories import text_query
|
||||
|
||||
chat = get_chat_handler()
|
||||
entities = get_entities()
|
||||
|
||||
mock_event_ctx.is_prevented_default.return_value = True
|
||||
mock_event_ctx.event.reply_message_chain = None
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
query = text_query('hello')
|
||||
query.resp_messages = []
|
||||
|
||||
should_interrupt = False
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply_message_chain is None:
|
||||
should_interrupt = True
|
||||
|
||||
assert should_interrupt is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_with_reply_continues(self):
|
||||
"""prevent_default with reply continues with that reply."""
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = True
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.reply_message_chain = text_chain('plugin reply')
|
||||
|
||||
query = text_query('hello')
|
||||
query.resp_messages = []
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_messages.append(event_ctx.event.reply_message_chain)
|
||||
|
||||
assert len(query.resp_messages) == 1
|
||||
|
||||
|
||||
class TestUserMessageAlteration:
|
||||
"""Tests for user_message alteration pattern."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_alters_message(self):
|
||||
"""User message can be altered to string."""
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = False
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.user_message_alter = 'altered text'
|
||||
|
||||
query = text_query('original')
|
||||
query.user_message = provider_message.Message(role='user', content=[])
|
||||
|
||||
# Pattern from handler
|
||||
if event_ctx.event.user_message_alter is not None:
|
||||
if isinstance(event_ctx.event.user_message_alter, str):
|
||||
query.user_message.content = [
|
||||
provider_message.ContentElement.from_text(event_ctx.event.user_message_alter)
|
||||
]
|
||||
|
||||
assert query.user_message.content[0].text == 'altered text'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_alters_message(self):
|
||||
"""User message can be altered to list."""
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
altered_list = [
|
||||
provider_message.ContentElement.from_text('part1'),
|
||||
provider_message.ContentElement.from_text('part2'),
|
||||
]
|
||||
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = False
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.user_message_alter = altered_list
|
||||
|
||||
query = text_query('original')
|
||||
query.user_message = provider_message.Message(role='user', content=[])
|
||||
|
||||
if isinstance(event_ctx.event.user_message_alter, list):
|
||||
query.user_message.content = event_ctx.event.user_message_alter
|
||||
|
||||
assert len(query.user_message.content) == 2
|
||||
|
||||
|
||||
class TestRunnerSelection:
|
||||
"""Tests for runner selection pattern."""
|
||||
|
||||
def test_runner_by_name(self):
|
||||
"""Runner is selected by name from config."""
|
||||
runner_name = 'local-agent'
|
||||
|
||||
# Simulate preregistered runners lookup - Mock with name attribute
|
||||
r1 = Mock()
|
||||
r1.name = 'local-agent'
|
||||
r2 = Mock()
|
||||
r2.name = 'dify'
|
||||
r3 = Mock()
|
||||
r3.name = 'n8n'
|
||||
preregistered_runners = [r1, r2, r3]
|
||||
|
||||
runner = None
|
||||
for r in preregistered_runners:
|
||||
if r.name == runner_name:
|
||||
runner = r
|
||||
break
|
||||
|
||||
assert runner is not None
|
||||
assert runner.name == 'local-agent'
|
||||
|
||||
def test_unknown_runner_raises(self):
|
||||
"""Unknown runner name raises error."""
|
||||
runner_name = 'unknown-runner'
|
||||
preregistered_runners = [
|
||||
Mock(name='local-agent'),
|
||||
Mock(name='dify'),
|
||||
]
|
||||
|
||||
runner = None
|
||||
for r in preregistered_runners:
|
||||
if r.name == runner_name:
|
||||
runner = r
|
||||
break
|
||||
|
||||
if runner is None:
|
||||
error_raised = True
|
||||
|
||||
assert error_raised is True
|
||||
|
||||
|
||||
class TestStreamingResponse:
|
||||
"""Tests for streaming response pattern."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chunks_pattern(self):
|
||||
"""Streaming produces multiple chunks."""
|
||||
chunks = ['Hello', ' World', '!']
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
# Simulate streaming generator
|
||||
async def stream_gen():
|
||||
for chunk in chunks:
|
||||
results.append(chunk)
|
||||
|
||||
await stream_gen()
|
||||
|
||||
assert len(results) == 3
|
||||
assert ''.join(results) == 'Hello World!'
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_resp_message_id(self):
|
||||
"""Streaming uses uuid for resp_message_id."""
|
||||
resp_message_id = str(uuid.uuid4())
|
||||
async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx):
|
||||
"""prevent_default with reply yields CONTINUE and updates resp_messages."""
|
||||
from tests.factories import text_query, text_chain
|
||||
|
||||
assert len(resp_message_id) == 36 # UUID format
|
||||
chat = get_chat_handler()
|
||||
entities = get_entities()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_pop_previous(self):
|
||||
"""Streaming pops previous response before adding new."""
|
||||
query = text_query('test')
|
||||
query.resp_messages = [Mock()] # Previous chunk
|
||||
query.resp_message_chain = [Mock()]
|
||||
reply_chain = text_chain('plugin reply')
|
||||
mock_event_ctx.is_prevented_default.return_value = True
|
||||
mock_event_ctx.event.reply_message_chain = reply_chain
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Pattern from handler: pop before adding new chunk
|
||||
if query.resp_messages:
|
||||
query.resp_messages.pop()
|
||||
if query.resp_message_chain:
|
||||
query.resp_message_chain.pop()
|
||||
|
||||
query.resp_messages.append(Mock()) # New chunk
|
||||
|
||||
assert len(query.resp_messages) == 1 # Only new chunk
|
||||
|
||||
|
||||
class TestNonStreamingResponse:
|
||||
"""Tests for non-streaming response pattern."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_response_pattern(self):
|
||||
"""Non-streaming produces single response."""
|
||||
query = text_query('test')
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
query = text_query('hello')
|
||||
query.resp_messages = []
|
||||
|
||||
# Simulate non-streaming runner
|
||||
async def run():
|
||||
yield Mock(readable_str=lambda: 'response text')
|
||||
|
||||
async for result in run():
|
||||
query.resp_messages.append(result)
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
assert len(query.resp_messages) == 1
|
||||
|
||||
|
||||
class TestExceptionHandling:
|
||||
"""Tests for exception handling pattern."""
|
||||
assert query.resp_messages[0] == reply_chain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_interrupts(self):
|
||||
"""Exception produces INTERRUPT result."""
|
||||
async def test_user_message_alter_string(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""user_message_alter as string updates query.user_message."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
text_query('test')
|
||||
pipeline_config = {
|
||||
'output': {
|
||||
'misc': {
|
||||
'exception-handling': 'show-hint',
|
||||
'failure-hint': 'Request failed.',
|
||||
}
|
||||
}
|
||||
}
|
||||
chat = get_chat_handler()
|
||||
|
||||
# Simulate exception
|
||||
exception = ValueError('provider error')
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
mock_event_ctx.event.user_message_alter = 'altered text'
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
|
||||
query = text_query('original')
|
||||
query.adapter = Mock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
query.user_message = Message(role='user', content=[])
|
||||
|
||||
if exception_handling == 'show-error':
|
||||
user_notice = f'{exception}'
|
||||
elif exception_handling == 'show-hint':
|
||||
user_notice = pipeline_config['output']['misc'].get('failure-hint', 'Request failed.')
|
||||
else: # hide
|
||||
user_notice = None
|
||||
class QuickRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='ok')
|
||||
|
||||
assert user_notice == 'Request failed.'
|
||||
set_runner(QuickRunner)
|
||||
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert query.user_message.content is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_show_error(self):
|
||||
"""show-error mode shows actual error."""
|
||||
text_query('test')
|
||||
pipeline_config = {
|
||||
'output': {
|
||||
'misc': {
|
||||
'exception-handling': 'show-error',
|
||||
}
|
||||
}
|
||||
}
|
||||
async def test_adapter_without_stream_method_defaults_non_stream(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""Adapter without is_stream_output_supported defaults to non-stream."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement
|
||||
|
||||
exception = ValueError('API timeout')
|
||||
chat = get_chat_handler()
|
||||
|
||||
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
mock_event_ctx.event.user_message_alter = None
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
if exception_handling == 'show-error':
|
||||
user_notice = f'{exception}'
|
||||
else:
|
||||
user_notice = 'Request failed.'
|
||||
|
||||
assert user_notice == 'API timeout'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_hide(self):
|
||||
"""hide mode shows no user notice."""
|
||||
text_query('test')
|
||||
pipeline_config = {
|
||||
'output': {
|
||||
'misc': {
|
||||
'exception-handling': 'hide',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ValueError('hidden error')
|
||||
|
||||
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
|
||||
|
||||
if exception_handling == 'hide':
|
||||
user_notice = None
|
||||
else:
|
||||
user_notice = 'Error'
|
||||
|
||||
assert user_notice is None
|
||||
|
||||
|
||||
class TestMessageHistoryUpdate:
|
||||
"""Tests for conversation message history."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_appended_to_conversation(self):
|
||||
"""User message and response appended to conversation."""
|
||||
query = text_query('test')
|
||||
query.session = Mock()
|
||||
query.session.using_conversation = Mock()
|
||||
query.session.using_conversation.messages = []
|
||||
query.adapter = Mock(spec=[])
|
||||
query.user_message = Message(role='user', content=[ContentElement.from_text('test')])
|
||||
|
||||
query.user_message = Mock()
|
||||
query.resp_messages = [Mock(), Mock()]
|
||||
class SingleRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='response')
|
||||
|
||||
# Pattern from handler after successful response
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
set_runner(SingleRunner)
|
||||
|
||||
assert len(query.session.using_conversation.messages) == 3
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
class TestStreamOutputCheck:
|
||||
"""Tests for stream output support check."""
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestChatHandlerStreaming:
|
||||
"""Tests for streaming behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_stream_check(self):
|
||||
"""Adapter is checked for stream support."""
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=True)
|
||||
async def test_streaming_chunks_collected(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""Streaming produces multiple results."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement, MessageChunk
|
||||
|
||||
is_stream = await adapter.is_stream_output_supported()
|
||||
chat = get_chat_handler()
|
||||
|
||||
assert is_stream is True
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
query = text_query('stream test')
|
||||
query.adapter = Mock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=True)
|
||||
query.adapter.create_message_card = AsyncMock()
|
||||
query.user_message = Message(role='user', content=[ContentElement.from_text('test')])
|
||||
|
||||
class StreamRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
yield MessageChunk(role='assistant', content='Hello', is_final=False)
|
||||
yield MessageChunk(role='assistant', content=' World', is_final=True)
|
||||
|
||||
set_runner(StreamRunner)
|
||||
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestChatHandlerExceptions:
|
||||
"""Tests for exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_no_stream_method(self):
|
||||
"""Adapter without method defaults to False."""
|
||||
adapter = Mock(spec=[]) # Empty spec, no methods
|
||||
# No is_stream_output_supported method
|
||||
async def test_runner_exception_yields_interrupt(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""Runner exception yields INTERRUPT with error notices."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
is_stream = False
|
||||
try:
|
||||
if hasattr(adapter, 'is_stream_output_supported'):
|
||||
is_stream = await adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
chat = get_chat_handler()
|
||||
entities = get_entities()
|
||||
|
||||
assert is_stream is False
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
query = text_query('fail test')
|
||||
query.adapter = Mock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
query.user_message = Message(role='user', content=[])
|
||||
|
||||
class TestTelemetryPattern:
|
||||
"""Tests for telemetry reporting pattern."""
|
||||
|
||||
def test_telemetry_payload_fields(self):
|
||||
"""Telemetry payload has expected fields."""
|
||||
query_id = 123
|
||||
adapter_name = 'TestAdapter'
|
||||
runner_name = 'local-agent'
|
||||
duration_ms = 150
|
||||
|
||||
payload = {
|
||||
'query_id': query_id,
|
||||
'adapter': adapter_name,
|
||||
'runner': runner_name,
|
||||
'duration_ms': duration_ms,
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
}
|
||||
|
||||
assert payload['query_id'] == 123
|
||||
assert payload['duration_ms'] == 150
|
||||
class FailingRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
raise ValueError('API error')
|
||||
yield
|
||||
|
||||
def test_telemetry_error_included(self):
|
||||
"""Telemetry includes error info on failure."""
|
||||
error_info = 'Traceback...'
|
||||
set_runner(FailingRunner)
|
||||
|
||||
payload = {
|
||||
'error': error_info,
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].user_notice == 'Request failed.'
|
||||
assert results[0].error_notice is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_show_error_mode(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""show-error mode shows actual exception."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
chat = get_chat_handler()
|
||||
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
query = text_query('error test')
|
||||
query.adapter = Mock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
query.user_message = Message(role='user', content=[])
|
||||
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'show-error'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
}
|
||||
|
||||
assert payload['error'] == 'Traceback...'
|
||||
class ErrorRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
raise ValueError('Custom error')
|
||||
yield
|
||||
|
||||
set_runner(ErrorRunner)
|
||||
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert results[0].user_notice == 'Custom error'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_hide_mode(self, fake_app, mock_event_ctx, set_runner):
|
||||
"""hide mode shows no user notice."""
|
||||
from tests.factories import text_query
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
chat = get_chat_handler()
|
||||
|
||||
mock_event_ctx.is_prevented_default.return_value = False
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
query = text_query('hide test')
|
||||
query.adapter = Mock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
query.user_message = Message(role='user', content=[])
|
||||
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'hide'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
}
|
||||
|
||||
class HideErrorRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
raise RuntimeError('hidden')
|
||||
yield
|
||||
|
||||
set_runner(HideErrorRunner)
|
||||
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert results[0].user_notice is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestChatHandlerHelper:
|
||||
"""Tests for helper methods."""
|
||||
|
||||
def test_cut_str_short(self, fake_app):
|
||||
"""cut_str returns short string unchanged."""
|
||||
chat = get_chat_handler()
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
result = handler.cut_str('short text')
|
||||
assert result == 'short text'
|
||||
|
||||
def test_cut_str_long(self, fake_app):
|
||||
"""cut_str truncates long string."""
|
||||
chat = get_chat_handler()
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
result = handler.cut_str('this is a very long string that exceeds twenty characters')
|
||||
assert '...' in result
|
||||
assert len(result) <= 23
|
||||
|
||||
def test_cut_str_multiline(self, fake_app):
|
||||
"""cut_str truncates multiline string."""
|
||||
chat = get_chat_handler()
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
result = handler.cut_str('first line\nsecond line')
|
||||
assert '...' in result
|
||||
@@ -1,308 +1,396 @@
|
||||
"""
|
||||
Unit tests for CommandHandler behavior patterns.
|
||||
Unit tests for CommandHandler - REAL imports.
|
||||
|
||||
Tests cover command processing patterns:
|
||||
- Command parsing and routing
|
||||
- Event emission pattern
|
||||
- Command manager interaction
|
||||
- Privilege handling
|
||||
|
||||
Uses pattern-based testing to avoid circular import issues in source code.
|
||||
Tests the actual CommandHandler class from production code.
|
||||
Uses tests.utils.import_isolation to break circular import chain safely.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from tests.factories import command_query
|
||||
from tests.factories import FakeApp, command_query
|
||||
|
||||
|
||||
class TestCommandParsingPattern:
|
||||
"""Tests for command parsing logic."""
|
||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||
|
||||
def test_command_text_extraction(self):
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain using isolated_sys_modules.
|
||||
|
||||
Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr
|
||||
|
||||
Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation.
|
||||
"""
|
||||
from tests.utils.import_isolation import (
|
||||
isolated_sys_modules,
|
||||
make_pipeline_handler_import_mocks,
|
||||
get_handler_modules_to_clear,
|
||||
)
|
||||
|
||||
mocks = make_pipeline_handler_import_mocks()
|
||||
clear = get_handler_modules_to_clear('command')
|
||||
|
||||
with isolated_sys_modules(mocks=mocks, clear=clear):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_app():
|
||||
"""Create FakeApp instance."""
|
||||
return FakeApp()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_ctx():
|
||||
"""Create mock event context."""
|
||||
ctx = Mock()
|
||||
ctx.is_prevented_default = Mock(return_value=False)
|
||||
ctx.event = Mock()
|
||||
ctx.event.reply_message_chain = None
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execute_factory():
|
||||
"""Factory fixture to create mock cmd_mgr.execute generators."""
|
||||
def _create_execute(
|
||||
text: str | None = 'ok',
|
||||
error: str | None = None,
|
||||
image_url: str | None = None,
|
||||
image_base64: str | None = None,
|
||||
file_url: str | None = None,
|
||||
):
|
||||
async def mock_execute(command_text, full_command_text, query, session):
|
||||
ret = Mock()
|
||||
ret.text = text
|
||||
ret.error = error
|
||||
ret.image_url = image_url
|
||||
ret.image_base64 = image_base64
|
||||
ret.file_url = file_url
|
||||
yield ret
|
||||
return mock_execute
|
||||
return _create_execute
|
||||
|
||||
|
||||
# ============== CACHED LAZY IMPORTS ==============
|
||||
|
||||
_command_handler_module = None
|
||||
_entities_module = None
|
||||
|
||||
|
||||
def get_command_handler():
|
||||
"""Import CommandHandler after circular import chain is mocked."""
|
||||
global _command_handler_module
|
||||
if _command_handler_module is None:
|
||||
from importlib import import_module
|
||||
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
|
||||
return _command_handler_module
|
||||
|
||||
|
||||
def get_entities():
|
||||
"""Import pipeline entities - uses real module."""
|
||||
global _entities_module
|
||||
if _entities_module is None:
|
||||
from importlib import import_module
|
||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||
return _entities_module
|
||||
|
||||
|
||||
# ============== REAL CommandHandler Tests ==============
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestCommandHandlerReal:
|
||||
"""Tests for real CommandHandler class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_import_works(self):
|
||||
"""Verify we can import the real handler class."""
|
||||
command = get_command_handler()
|
||||
assert hasattr(command, 'CommandHandler')
|
||||
handler_cls = command.CommandHandler
|
||||
assert handler_cls.__name__ == 'CommandHandler'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_creation(self, fake_app):
|
||||
"""CommandHandler can be instantiated."""
|
||||
command = get_command_handler()
|
||||
handler = command.CommandHandler(fake_app)
|
||||
assert handler.ap is fake_app
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_parsing_extracts_command_name(self, fake_app, mock_event_ctx):
|
||||
"""Command text is extracted after prefix."""
|
||||
# Simulate the parsing pattern from command handler
|
||||
full_command_text = "/help arg1 arg2"
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
# Handler strips first character (prefix)
|
||||
command_text = full_command_text.strip()[1:]
|
||||
parts = command_text.split(' ')
|
||||
executed_commands = []
|
||||
async def track_execute(command_text, full_command_text, query, session):
|
||||
executed_commands.append(command_text)
|
||||
ret = Mock()
|
||||
ret.text = 'ok'
|
||||
ret.error = None
|
||||
ret.image_url = None
|
||||
ret.image_base64 = None
|
||||
ret.file_url = None
|
||||
yield ret
|
||||
|
||||
assert parts[0] == 'help'
|
||||
assert parts[1:] == ['arg1', 'arg2']
|
||||
fake_app.cmd_mgr.execute = track_execute
|
||||
|
||||
def test_empty_command_parts(self):
|
||||
"""Empty command has no parts."""
|
||||
full_command_text = "/"
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('help arg1 arg2')
|
||||
|
||||
command_text = full_command_text.strip()[1:]
|
||||
parts = command_text.split(' ')
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert parts == ['']
|
||||
assert executed_commands[0] == 'help arg1 arg2'
|
||||
|
||||
def test_single_command_no_args(self):
|
||||
"""Single command has no arguments."""
|
||||
full_command_text = "/status"
|
||||
|
||||
command_text = full_command_text.strip()[1:]
|
||||
parts = command_text.split(' ')
|
||||
|
||||
assert parts == ['status']
|
||||
|
||||
|
||||
class TestCommandEventCreation:
|
||||
"""Tests for command event creation pattern."""
|
||||
|
||||
def test_event_type_by_launcher_type(self):
|
||||
"""Event type differs for person/group."""
|
||||
import langbot_plugin.api.entities.events as events
|
||||
|
||||
# Person command
|
||||
person_event_class = events.PersonCommandSent
|
||||
|
||||
# Group command
|
||||
group_event_class = events.GroupCommandSent
|
||||
|
||||
assert person_event_class is not None
|
||||
assert group_event_class is not None
|
||||
|
||||
def test_event_fields_pattern(self):
|
||||
"""Command event should have expected fields."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Admin users get privilege level 2."""
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
|
||||
launcher_type = LauncherTypes.PERSON.value
|
||||
launcher_id = '12345'
|
||||
sender_id = '12345'
|
||||
command = 'help'
|
||||
params = ['arg1', 'arg2']
|
||||
is_admin = False
|
||||
command = get_command_handler()
|
||||
|
||||
# Simulate event creation pattern
|
||||
event_data = {
|
||||
'launcher_type': launcher_type,
|
||||
'launcher_id': launcher_id,
|
||||
'sender_id': sender_id,
|
||||
'command': command,
|
||||
'params': params,
|
||||
'is_admin': is_admin,
|
||||
}
|
||||
fake_app.instance_config.data = {'admins': ['person_12345']}
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory()
|
||||
|
||||
assert event_data['command'] == 'help'
|
||||
assert event_data['params'] == ['arg1', 'arg2']
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('status')
|
||||
query.launcher_type = LauncherTypes.PERSON
|
||||
query.launcher_id = 12345
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
class TestPrivilegeCheckPattern:
|
||||
"""Tests for privilege/admin check."""
|
||||
|
||||
def test_admin_check_by_session_id(self):
|
||||
"""Admin is checked by session_id format."""
|
||||
admins = ['person_12345', 'group_99999']
|
||||
launcher_type = 'person'
|
||||
launcher_id = '12345'
|
||||
|
||||
session_id = f'{launcher_type}_{launcher_id}'
|
||||
is_admin = session_id in admins
|
||||
|
||||
assert is_admin is True
|
||||
|
||||
def test_non_admin_check(self):
|
||||
"""Non-admin user has privilege 1."""
|
||||
admins = ['person_12345']
|
||||
launcher_type = 'person'
|
||||
launcher_id = '67890'
|
||||
|
||||
session_id = f'{launcher_type}_{launcher_id}'
|
||||
is_admin = session_id in admins
|
||||
|
||||
assert is_admin is False
|
||||
|
||||
def test_privilege_levels(self):
|
||||
"""Privilege level 1 for normal, 2 for admin."""
|
||||
normal_privilege = 1
|
||||
admin_privilege = 2
|
||||
|
||||
admins = ['person_12345']
|
||||
|
||||
# Normal user
|
||||
session_id = 'person_67890'
|
||||
privilege = 2 if session_id in admins else 1
|
||||
assert privilege == normal_privilege
|
||||
|
||||
# Admin user
|
||||
session_id = 'person_12345'
|
||||
privilege = 2 if session_id in admins else 1
|
||||
assert privilege == admin_privilege
|
||||
|
||||
|
||||
class TestCommandResultHandling:
|
||||
"""Tests for command result handling patterns."""
|
||||
call_args = fake_app.plugin_connector.emit_event.call_args
|
||||
event = call_args[0][0]
|
||||
assert event.is_admin is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_result_pattern(self):
|
||||
"""Text result is converted to message."""
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
async def test_non_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Non-admin users get privilege level 1."""
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
|
||||
# Simulate command return
|
||||
ret = Mock()
|
||||
ret.text = 'Command output'
|
||||
ret.error = None
|
||||
ret.image_url = None
|
||||
ret.image_base64 = None
|
||||
ret.file_url = None
|
||||
command = get_command_handler()
|
||||
|
||||
# Pattern from handler: build content list
|
||||
content = []
|
||||
if ret.text is not None:
|
||||
content.append(provider_message.ContentElement.from_text(ret.text))
|
||||
fake_app.instance_config.data = {'admins': ['person_12345']}
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory()
|
||||
|
||||
assert len(content) == 1
|
||||
assert content[0].type == 'text'
|
||||
assert content[0].text == 'Command output'
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('status')
|
||||
query.launcher_type = LauncherTypes.PERSON
|
||||
query.launcher_id = 67890
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
call_args = fake_app.plugin_connector.emit_event.call_args
|
||||
event = call_args[0][0]
|
||||
assert event.is_admin is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_result_pattern(self):
|
||||
async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx):
|
||||
"""prevent_default with reply yields CONTINUE."""
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
command = get_command_handler()
|
||||
entities = get_entities()
|
||||
|
||||
reply_chain = text_chain('plugin reply')
|
||||
mock_event_ctx.is_prevented_default.return_value = True
|
||||
mock_event_ctx.event.reply_message_chain = reply_chain
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('test')
|
||||
query.resp_messages = []
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
assert len(query.resp_messages) == 1
|
||||
assert query.resp_messages[0] == reply_chain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx):
|
||||
"""prevent_default without reply yields INTERRUPT."""
|
||||
command = get_command_handler()
|
||||
entities = get_entities()
|
||||
|
||||
mock_event_ctx.is_prevented_default.return_value = True
|
||||
mock_event_ctx.event.reply_message_chain = None
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('test')
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_type_person_command(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Person launcher creates PersonCommandSent event."""
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
from langbot_plugin.api.entities import events
|
||||
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory()
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('help')
|
||||
query.launcher_type = LauncherTypes.PERSON
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
call_args = fake_app.plugin_connector.emit_event.call_args
|
||||
event = call_args[0][0]
|
||||
assert isinstance(event, events.PersonCommandSent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_type_group_command(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Group launcher creates GroupCommandSent event."""
|
||||
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
|
||||
from langbot_plugin.api.entities import events
|
||||
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory()
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('help')
|
||||
query.launcher_type = LauncherTypes.GROUP
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
call_args = fake_app.plugin_connector.emit_event.call_args
|
||||
event = call_args[0][0]
|
||||
assert isinstance(event, events.GroupCommandSent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_result_text(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Text result is added to resp_messages."""
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory(text='Command output')
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('echo')
|
||||
query.resp_messages = []
|
||||
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(query.resp_messages) == 1
|
||||
msg = query.resp_messages[0]
|
||||
assert msg.role == 'command'
|
||||
assert len(msg.content) == 1
|
||||
assert msg.content[0].type == 'text'
|
||||
assert msg.content[0].text == 'Command output'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_result_error(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Error result creates error message."""
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory(text=None, error='Command failed')
|
||||
|
||||
ret = Mock()
|
||||
ret.text = None
|
||||
ret.error = 'Command failed'
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('fail')
|
||||
query.resp_messages = []
|
||||
|
||||
# Error handling pattern
|
||||
if ret.error is not None:
|
||||
msg = provider_message.Message(
|
||||
role='command',
|
||||
content=str(ret.error),
|
||||
)
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(query.resp_messages) == 1
|
||||
msg = query.resp_messages[0]
|
||||
assert msg.role == 'command'
|
||||
assert msg.content == 'Command failed'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_result_pattern(self):
|
||||
"""Image result is added to content."""
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
async def test_command_result_image_url(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Image URL result is added to content."""
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory(
|
||||
text='Here is the image:',
|
||||
image_url='https://example.com/image.png'
|
||||
)
|
||||
|
||||
ret = Mock()
|
||||
ret.text = 'Here is the image:'
|
||||
ret.error = None
|
||||
ret.image_url = 'https://example.com/image.png'
|
||||
ret.image_base64 = None
|
||||
ret.file_url = None
|
||||
|
||||
content = []
|
||||
if ret.text is not None:
|
||||
content.append(provider_message.ContentElement.from_text(ret.text))
|
||||
if ret.image_url is not None:
|
||||
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
|
||||
|
||||
assert len(content) == 2
|
||||
assert content[0].type == 'text'
|
||||
assert content[1].type == 'image_url'
|
||||
|
||||
|
||||
class TestPreventDefaultHandling:
|
||||
"""Tests for prevent_default handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_with_reply(self):
|
||||
"""prevent_default with reply continues pipeline."""
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
# Simulate event context
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = True
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.reply_message_chain = text_chain('plugin reply')
|
||||
|
||||
query = command_query('test')
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('image')
|
||||
query.resp_messages = []
|
||||
|
||||
# Pattern from handler
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply_message_chain is not None:
|
||||
query.resp_messages.append(event_ctx.event.reply_message_chain)
|
||||
# yield CONTINUE
|
||||
else:
|
||||
# yield INTERRUPT
|
||||
pass
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert len(query.resp_messages) == 1
|
||||
msg = query.resp_messages[0]
|
||||
assert len(msg.content) == 2
|
||||
assert msg.content[0].type == 'text'
|
||||
assert msg.content[1].type == 'image_url'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prevent_default_without_reply(self):
|
||||
"""prevent_default without reply interrupts."""
|
||||
event_ctx = Mock()
|
||||
event_ctx.is_prevented_default.return_value = True
|
||||
event_ctx.event = Mock()
|
||||
event_ctx.event.reply_message_chain = None
|
||||
async def test_command_result_empty_interrupts(self, fake_app, mock_event_ctx, mock_execute_factory):
|
||||
"""Empty result yields INTERRUPT."""
|
||||
command = get_command_handler()
|
||||
entities = get_entities()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory(text=None)
|
||||
|
||||
query = command_query('test')
|
||||
query.resp_messages = []
|
||||
handler = command.CommandHandler(fake_app)
|
||||
query = command_query('empty')
|
||||
|
||||
should_interrupt = False
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply_message_chain is None:
|
||||
should_interrupt = True
|
||||
results = []
|
||||
async for result in handler.handle(query):
|
||||
results.append(result)
|
||||
|
||||
assert should_interrupt is True
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
|
||||
class TestStringTruncationHelper:
|
||||
"""Tests for cut_str helper method."""
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestCommandHandlerHelper:
|
||||
"""Tests for helper methods."""
|
||||
|
||||
def test_short_string_no_change(self):
|
||||
"""Short string is not truncated."""
|
||||
# Pattern from handler.cut_str
|
||||
def cut_str(s: str) -> str:
|
||||
s0 = s.split('\n')[0]
|
||||
if len(s0) > 20 or '\n' in s:
|
||||
s0 = s0[:20] + '...'
|
||||
return s0
|
||||
|
||||
result = cut_str('short text')
|
||||
def test_cut_str_short(self, fake_app):
|
||||
"""cut_str returns short string unchanged."""
|
||||
command = get_command_handler()
|
||||
handler = command.CommandHandler(fake_app)
|
||||
result = handler.cut_str('short text')
|
||||
assert result == 'short text'
|
||||
|
||||
def test_long_string_truncated(self):
|
||||
"""Long string is truncated."""
|
||||
def cut_str(s: str) -> str:
|
||||
s0 = s.split('\n')[0]
|
||||
if len(s0) > 20 or '\n' in s:
|
||||
s0 = s0[:20] + '...'
|
||||
return s0
|
||||
|
||||
result = cut_str('this is a very long string that exceeds twenty characters')
|
||||
def test_cut_str_long(self, fake_app):
|
||||
"""cut_str truncates long string."""
|
||||
command = get_command_handler()
|
||||
handler = command.CommandHandler(fake_app)
|
||||
result = handler.cut_str('this is a very long string that exceeds twenty characters')
|
||||
assert '...' in result
|
||||
assert len(result) <= 23
|
||||
|
||||
def test_multiline_truncated(self):
|
||||
"""Multiline string is truncated."""
|
||||
def cut_str(s: str) -> str:
|
||||
s0 = s.split('\n')[0]
|
||||
if len(s0) > 20 or '\n' in s:
|
||||
s0 = s0[:20] + '...'
|
||||
return s0
|
||||
|
||||
result = cut_str('first line\nsecond line\nthird')
|
||||
assert '...' in result
|
||||
|
||||
|
||||
class TestCommandPrefixConfiguration:
|
||||
"""Tests for command prefix configuration."""
|
||||
|
||||
def test_default_prefixes(self):
|
||||
"""Default prefixes are slash and exclamation."""
|
||||
default_prefixes = ['/', '!']
|
||||
assert '/' in default_prefixes
|
||||
assert '!' in default_prefixes
|
||||
|
||||
def test_custom_prefix(self):
|
||||
"""Custom prefix can be configured."""
|
||||
custom_prefix = '#'
|
||||
full_text = f'{custom_prefix}help'
|
||||
|
||||
# Would be checked against config['command']['prefix']
|
||||
is_command = full_text.startswith(custom_prefix)
|
||||
assert is_command is True
|
||||
def test_cut_str_multiline(self, fake_app):
|
||||
"""cut_str truncates multiline string."""
|
||||
command = get_command_handler()
|
||||
handler = command.CommandHandler(fake_app)
|
||||
result = handler.cut_str('first line\nsecond line')
|
||||
assert '...' in result
|
||||
162
tests/utils/import_isolation.py
Normal file
162
tests/utils/import_isolation.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
sys.modules isolation utilities for breaking circular import chains.
|
||||
|
||||
Provides safe, reversible sys.modules manipulation for tests that need to
|
||||
import modules with heavy import-time side effects (auto-registration,
|
||||
circular dependencies, etc.).
|
||||
|
||||
Usage pattern:
|
||||
1. Create mock objects for modules that cause circular imports
|
||||
2. Use isolated_sys_modules to temporarily patch sys.modules
|
||||
3. Import target module after patching
|
||||
4. Test the real production code
|
||||
5. Context manager automatically restores original sys.modules state
|
||||
|
||||
Key principle: mock only what breaks the import chain, not what the code needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import enum
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class MockLifecycleControlScope(enum.Enum):
|
||||
"""Mock enum for breaking circular import in core.entities."""
|
||||
APPLICATION = 'application'
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
|
||||
@contextmanager
|
||||
def isolated_sys_modules(
|
||||
mocks: dict[str, object],
|
||||
clear: list[str] | None = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager for isolated sys.modules manipulation.
|
||||
|
||||
Safely patches sys.modules with mocks and clears specified modules,
|
||||
then restores original state on exit. This prevents test pollution
|
||||
where mocks leak into subsequent tests.
|
||||
|
||||
Args:
|
||||
mocks: Dict mapping module names to mock objects.
|
||||
These will be set in sys.modules during the context.
|
||||
clear: List of module names to remove from sys.modules before
|
||||
entering the context. Useful for forcing re-import of
|
||||
modules that depend on mocked modules.
|
||||
|
||||
Example:
|
||||
>>> with isolated_sys_modules(
|
||||
... mocks={'my_pkg.heavy_module': MagicMock()},
|
||||
... clear=['my_pkg.target_module'],
|
||||
... ):
|
||||
... from my_pkg.target_module import MyClass # Safe import
|
||||
|
||||
Note:
|
||||
- Modules in both mocks and clear will be mocked (not cleared)
|
||||
- Original state is restored even if exception occurs
|
||||
- Modules not in sys.modules before context are removed after
|
||||
"""
|
||||
clear = clear or []
|
||||
touched = set(mocks.keys()) | set(clear)
|
||||
|
||||
# Save original state for modules we'll touch
|
||||
saved: dict[str, object] = {}
|
||||
for name in touched:
|
||||
if name in sys.modules:
|
||||
saved[name] = sys.modules[name]
|
||||
|
||||
try:
|
||||
# Clear modules first (force re-import)
|
||||
for name in clear:
|
||||
if name not in mocks: # Don't clear if we're mocking it
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
# Apply mocks
|
||||
for name, module in mocks.items():
|
||||
sys.modules[name] = module
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore original state - critical for test isolation
|
||||
for name in touched:
|
||||
if name in saved:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
# Wasn't in sys.modules originally, remove it
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
|
||||
"""
|
||||
Create mock objects needed to break circular import chain in handlers.
|
||||
|
||||
The import chain:
|
||||
handler → core.app → pipeline.controller → http_controller
|
||||
→ groups/plugins → taskmgr (partial init)
|
||||
|
||||
This function creates minimal mocks that break this chain without
|
||||
affecting the handler's ability to use real pipeline.entities
|
||||
(needed for ResultType enum comparisons).
|
||||
|
||||
Returns:
|
||||
Dict mapping module names to MagicMock objects.
|
||||
|
||||
Note:
|
||||
These mocks are intentionally minimal - they only provide what's
|
||||
needed to prevent circular imports. The actual handler code uses
|
||||
real imports from langbot_plugin.api and langbot.pkg.pipeline.entities.
|
||||
"""
|
||||
# Mock core.entities with proper Enum class
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
# Mock core.app - Application class is referenced but not instantiated
|
||||
mock_app = MagicMock()
|
||||
|
||||
# Mock provider.runner - has preregistered_runners attribute
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.preregistered_runners = [] # Empty by default, tests override
|
||||
|
||||
# Mock utils.importutil - prevents auto-import of runners
|
||||
mock_importutil = MagicMock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
||||
|
||||
return {
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.pipeline.controller': MagicMock(),
|
||||
'langbot.pkg.pipeline.pipelinemgr': MagicMock(),
|
||||
'langbot.pkg.pipeline.process.process': MagicMock(),
|
||||
'langbot.pkg.provider.runner': mock_runner,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
}
|
||||
|
||||
|
||||
def get_handler_modules_to_clear(handler_name: str) -> list[str]:
|
||||
"""
|
||||
Get list of handler-related modules to clear before import.
|
||||
|
||||
These modules need to be cleared so they're re-imported after
|
||||
the circular import chain is mocked. Without clearing, they'd
|
||||
already be in sys.modules (possibly partially initialized).
|
||||
|
||||
Args:
|
||||
handler_name: The handler file name (e.g., 'chat', 'command')
|
||||
|
||||
Returns:
|
||||
List of module names to clear.
|
||||
"""
|
||||
return [
|
||||
'langbot.pkg.pipeline.process.handler',
|
||||
'langbot.pkg.pipeline.process.handlers',
|
||||
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
|
||||
]
|
||||
Reference in New Issue
Block a user