From 59871c3118f5e89e59506a0b2674f612e5efc1d3 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Fri, 8 May 2026 22:41:48 +0800 Subject: [PATCH] refactor(test): consolidate FakeApp and add sys.modules isolation utility - Extract tests/utils/import_isolation.py with isolated_sys_modules context manager - Extend tests/factories/app.py FakeApp with handler-specific attributes - Refactor test_chat_handler.py to use centralized FakeApp and cached imports - Refactor test_command_handler.py with mock_execute_factory fixture - Refactor test_smoke.py to move import-time sys.modules manipulation into fixture - Add SQLite migration integration tests (G-002) - Add HTTP API smoke integration tests (G-005) - Update CI workflow to call pytest for SQLite migrations (G-004) Co-Authored-By: Claude Opus 4.7 --- .github/workflows/test-migrations.yml | 56 +- tests/README.md | 56 +- tests/factories/app.py | 24 + tests/integration/__init__.py | 6 + tests/integration/api/__init__.py | 5 + tests/integration/api/test_smoke.py | 347 +++++++++ tests/integration/persistence/__init__.py | 5 + .../persistence/test_migrations.py | 223 ++++++ .../unit_tests/pipeline/test_chat_handler.py | 722 +++++++++--------- .../pipeline/test_command_handler.py | 592 ++++++++------ tests/utils/import_isolation.py | 162 ++++ 11 files changed, 1540 insertions(+), 658 deletions(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/api/__init__.py create mode 100644 tests/integration/api/test_smoke.py create mode 100644 tests/integration/persistence/__init__.py create mode 100644 tests/integration/persistence/test_migrations.py create mode 100644 tests/utils/import_isolation.py diff --git a/.github/workflows/test-migrations.yml b/.github/workflows/test-migrations.yml index fa2d30ae..2f2f2195 100644 --- a/.github/workflows/test-migrations.yml +++ b/.github/workflows/test-migrations.yml @@ -9,11 +9,13 @@ on: paths: - 'src/langbot/pkg/persistence/**' - 'src/langbot/pkg/entity/persistence/**' + - 'tests/integration/persistence/**' pull_request: types: [opened, synchronize, reopened, ready_for_review] paths: - 'src/langbot/pkg/persistence/**' - 'src/langbot/pkg/entity/persistence/**' + - 'tests/integration/persistence/**' jobs: test-migrations-sqlite: @@ -34,53 +36,13 @@ jobs: - name: Install dependencies run: uv sync --dev - - name: Test Alembic upgrade (SQLite) - run: | - uv run python -c " - import asyncio - from sqlalchemy.ext.asyncio import create_async_engine - from langbot.pkg.entity.persistence.base import Base - from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current - - async def main(): - engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db') - - # Create all tables (simulates existing DB) - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - # Stamp baseline - await run_alembic_stamp(engine, '0001_baseline') - rev = await get_alembic_current(engine) - assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}' - print(f'Stamped: {rev}') - - # Upgrade to head - await run_alembic_upgrade(engine, 'head') - rev = await get_alembic_current(engine) - print(f'After upgrade: {rev}') - assert rev is not None, 'Expected a revision after upgrade' - - # Verify idempotent - await run_alembic_upgrade(engine, 'head') - rev2 = await get_alembic_current(engine) - assert rev2 == rev, f'Expected {rev}, got {rev2}' - print(f'Idempotent check passed: {rev2}') - - # Fresh DB: upgrade from scratch - engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db') - async with engine2.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - await run_alembic_upgrade(engine2, 'head') - rev3 = await get_alembic_current(engine2) - print(f'Fresh DB upgrade: {rev3}') - assert rev3 is not None - - print('All SQLite migration tests passed!') - - asyncio.run(main()) - " + - name: Run SQLite migration tests + run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short + # TODO(G-003): Migrate PostgreSQL tests to pytest integration tests + # PostgreSQL requires external database service, which will be handled in G-003. + # The inline script below will be replaced with: + # uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short test-migrations-postgres: name: Migrations (PostgreSQL) runs-on: ubuntu-latest @@ -168,4 +130,4 @@ jobs: print('All PostgreSQL migration tests passed!') asyncio.run(main()) - " + " \ No newline at end of file diff --git a/tests/README.md b/tests/README.md index 8e40607b..70c90da9 100644 --- a/tests/README.md +++ b/tests/README.md @@ -17,6 +17,14 @@ tests/ │ ├── message.py # Message/query factories │ ├── provider.py # FakeProvider factory │ └── platform.py # FakePlatform factory +├── integration/ # Integration tests (real resources) +│ ├── __init__.py +│ ├── api/ # HTTP API tests +│ │ ├── __init__.py +│ │ └── test_smoke.py # API smoke tests +│ └── persistence/ # Database/persistence tests +│ ├── __init__.py +│ └── test_migrations.py # Alembic migration tests ├── smoke/ # Smoke tests (quick validation) │ └── test_fake_message_flow.py ├── unit_tests/ # Unit tests @@ -28,6 +36,9 @@ tests/ │ ├── plugin/ # Plugin system tests │ ├── provider/ # Provider tests │ └── storage/ # Storage tests +├── utils/ # Test utilities +│ ├── __init__.py +│ └── import_isolation.py # sys.modules isolation for circular imports └── README.md # This file ``` @@ -147,13 +158,51 @@ uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_ # Run only unit tests uv run pytest tests/unit_tests/ -m unit -# Run only integration tests (when available) -uv run pytest tests/ -m integration +# Run only integration tests +uv run pytest tests/integration/ -m integration + +# Run integration tests excluding slow ones +uv run pytest tests/integration/ -m "not slow" -q # Skip slow tests uv run pytest tests/unit_tests/ -m "not slow" ``` +### Running integration tests + +Integration tests validate real system behavior with actual database/network resources. + +```bash +# Run all integration tests (excluding slow ones) +uv run pytest tests/integration/ -m "not slow" -q + +# Run SQLite migration integration tests +uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short + +# Run API smoke integration tests +uv run pytest tests/integration/api/test_smoke.py -q + +# Run with verbose output +uv run pytest tests/integration/ -v +``` + +Note: Integration tests use: +- Temporary databases (tmp_path) for persistence tests +- Fake app/services for API tests (no real provider/platform) +- Do not require external services + +### Running migration tests locally + +SQLite migration tests can be run locally without any external dependencies: + +```bash +# SQLite migration tests (uses tmp_path, no external DB needed) +uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short +``` + +CI workflow `.github/workflows/test-migrations.yml` runs SQLite tests using pytest. +PostgreSQL migration tests still use inline Python script (will be migrated to pytest in G-003). + ### Known Issues Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues: @@ -250,7 +299,10 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun ## Future Enhancements +- [x] Add integration tests for database migrations (SQLite) +- [ ] Add PostgreSQL migration integration tests (G-003) - [ ] Add integration tests for full pipeline execution +- [x] Add API smoke integration tests - [ ] Add E2E tests - [ ] Add performance benchmarks - [ ] Add mutation testing for better coverage quality diff --git a/tests/factories/app.py b/tests/factories/app.py index cf1bbaf4..5f36df84 100644 --- a/tests/factories/app.py +++ b/tests/factories/app.py @@ -18,6 +18,8 @@ class FakeApp: command_prefix: list[str] = ["/", "!"], command_enable: bool = True, pipeline_concurrency: int = 10, + admins: list[str] | None = None, + **extra_attrs, ): self.logger = self._create_mock_logger() self.sess_mgr = self._create_mock_session_manager() @@ -30,9 +32,19 @@ class FakeApp: command_prefix=command_prefix, command_enable=command_enable, pipeline_concurrency=pipeline_concurrency, + admins=admins or [], ) self.task_mgr = self._create_mock_task_manager() + # Handler-specific optional attributes + self.telemetry = self._create_mock_telemetry() + self.survey = None + self.cmd_mgr = self._create_mock_cmd_mgr() + + # Apply any extra attributes for specific test scenarios + for name, value in extra_attrs.items(): + setattr(self, name, value) + # Captured outbound messages (for assertions) self._outbound_messages: list = [] @@ -82,11 +94,13 @@ class FakeApp: command_prefix: list[str], command_enable: bool, pipeline_concurrency: int, + admins: list[str], ): instance_config = Mock() instance_config.data = { "command": {"prefix": command_prefix, "enable": command_enable}, "concurrency": {"pipeline": pipeline_concurrency}, + "admins": admins, } return instance_config @@ -95,6 +109,16 @@ class FakeApp: task_mgr.create_task = Mock() return task_mgr + def _create_mock_telemetry(self): + telemetry = AsyncMock() + telemetry.start_send_task = AsyncMock() + return telemetry + + def _create_mock_cmd_mgr(self): + cmd_mgr = AsyncMock() + cmd_mgr.execute = AsyncMock() + return cmd_mgr + def capture_message(self, message): """Capture an outbound message for test assertions.""" self._outbound_messages.append(message) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..a261bc7b --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,6 @@ +""" +Integration tests package. + +These tests validate real system behavior with actual database/network resources. +Run with: uv run pytest tests/integration/ -m "not slow" -q +""" \ No newline at end of file diff --git a/tests/integration/api/__init__.py b/tests/integration/api/__init__.py new file mode 100644 index 00000000..99968664 --- /dev/null +++ b/tests/integration/api/__init__.py @@ -0,0 +1,5 @@ +""" +API integration tests package. + +Tests for HTTP API endpoints using Quart test client. +""" \ No newline at end of file diff --git a/tests/integration/api/test_smoke.py b/tests/integration/api/test_smoke.py new file mode 100644 index 00000000..18ea47aa --- /dev/null +++ b/tests/integration/api/test_smoke.py @@ -0,0 +1,347 @@ +""" +API smoke integration tests. + +Tests real HTTP API behavior using Quart test client. +Validates controller/service/routing wiring without real provider/platform. + +Run: uv run pytest tests/integration/api/test_smoke.py -q +""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock, AsyncMock, Mock + +from tests.factories import FakeApp + + +pytestmark = pytest.mark.integration + + +# ============== FIXTURE FOR SYS.MODULES ISOLATION ============== + +@pytest.fixture(scope='module') +def mock_circular_import_chain(): + """ + Break circular import chain for API controller using isolated_sys_modules. + + Chain: http_controller → groups/plugins → core.app → pipeline entities + + We need to mock core.app to prevent the circular chain when importing HTTPController. + But we must allow groups to be imported to populate preregistered_groups. + """ + from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope + + # Mock core.app with minimal Application that groups can reference + class FakeMinimalApplication: + pass + + mock_app = MagicMock() + mock_app.Application = FakeMinimalApplication + + # Mock core.entities with proper Enum + mock_entities = MagicMock() + mock_entities.LifecycleControlScope = MockLifecycleControlScope + + # Modules to clear (force re-import after mocking) + clear = [ + 'langbot.pkg.api.http.controller.group', + 'langbot.pkg.api.http.controller.groups', + 'langbot.pkg.api.http.controller.groups.system', + 'langbot.pkg.api.http.controller.groups.user', + 'langbot.pkg.api.http.controller.main', + ] + + with isolated_sys_modules( + mocks={ + 'langbot.pkg.core.app': mock_app, + 'langbot.pkg.core.entities': mock_entities, + }, + clear=clear, + ): + # Import groups after mocking core.app/core.entities + import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401 + import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401 + import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401 + + yield + + +# ============== FAKE APPLICATION FOR API TESTS ============== + +@pytest.fixture +def fake_api_app(): + """ + Create minimal FakeApp for API smoke tests with all required services. + + Uses tests.factories.FakeApp as base and adds API-specific services. + """ + app = FakeApp() + + # API-specific config + app.instance_config.data.update({ + 'api': {'port': 5300}, + 'plugin': {'enable_marketplace': True}, + 'space': {'url': 'https://space.langbot.app'}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + }) + + # API-specific services + app.user_service = Mock() + app.user_service.is_initialized = AsyncMock(return_value=False) + app.user_service.authenticate = AsyncMock(return_value='fake_token') + app.user_service.create_user = AsyncMock() + app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token')) + app.user_service.get_user_by_email = AsyncMock(return_value=Mock()) + app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token') + + app.apikey_service = Mock() + app.apikey_service.verify_api_key = AsyncMock(return_value=True) + + app.maintenance_service = Mock() + app.maintenance_service.get_storage_analysis = AsyncMock(return_value={}) + + app.plugin_connector.is_enable_plugin = False + app.plugin_connector.ping_plugin_runtime = AsyncMock() + + app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []}) + app.task_mgr.get_task_by_id = Mock(return_value=None) + + # Required by controller groups + app.model_mgr = Mock() + app.platform_mgr = Mock() + app.pipeline_pool = Mock() + app.pipeline_mgr = Mock() + + return app + + +# ============== QUART TEST CLIENT FIXTURE ============== + +@pytest.fixture +async def quart_test_client(fake_api_app): + """ + Create Quart test client with real HTTPController and route registration. + + Requires mock_circular_import_chain fixture to run first (usefixtures). + """ + from langbot.pkg.api.http.controller.main import HTTPController + + controller = HTTPController(fake_api_app) + await controller.initialize() + + client = controller.quart_app.test_client() + + yield client + + +# ============== API SMOKE TESTS ============== + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestHealthEndpoint: + """Tests for /healthz endpoint - simplest smoke test.""" + + @pytest.mark.asyncio + async def test_healthz_returns_ok(self, quart_test_client): + """ + /healthz endpoint returns {'code': 0, 'msg': 'ok'}. + + This tests: + - HTTPController instantiation + - Quart app creation + - Route registration + - Basic response handling + """ + response = await quart_test_client.get('/healthz') + + assert response.status_code == 200 + data = await response.get_json() + assert data == {'code': 0, 'msg': 'ok'} + + @pytest.mark.asyncio + async def test_healthz_no_auth_required(self, quart_test_client): + """ + /healthz doesn't require authentication. + + Tests that AuthType.NONE endpoints work without headers. + """ + response = await quart_test_client.get('/healthz') + assert response.status_code == 200 + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestSystemEndpoint: + """Tests for /api/v1/system endpoints.""" + + @pytest.mark.asyncio + async def test_system_info_no_auth(self, quart_test_client): + """ + /api/v1/system/info returns system information without auth. + + AuthType.NONE endpoint. + """ + response = await quart_test_client.get('/api/v1/system/info') + + assert response.status_code == 200 + data = await response.get_json() + + # Verify response structure + assert data['code'] == 0 + assert data['msg'] == 'ok' + assert 'data' in data + + # Verify expected fields + system_data = data['data'] + assert 'version' in system_data + assert 'debug' in system_data + assert 'edition' in system_data + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestProtectedEndpoints: + """Tests for authentication/authorization behavior.""" + + @pytest.mark.asyncio + async def test_protected_endpoint_rejects_no_token(self, quart_test_client): + """ + Protected endpoint (USER_TOKEN) returns 401 without auth. + + Tests that AuthType.USER_TOKEN properly rejects unauthorized requests. + """ + # /api/v1/user/check-token requires USER_TOKEN + response = await quart_test_client.get('/api/v1/user/check-token') + + assert response.status_code == 401 + data = await response.get_json() + + # Verify error response structure + assert data['code'] == -1 + assert 'msg' in data + + @pytest.mark.asyncio + async def test_protected_endpoint_with_invalid_token(self, quart_test_client): + """ + Protected endpoint returns 401 with invalid token. + """ + response = await quart_test_client.get( + '/api/v1/user/check-token', + headers={'Authorization': 'Bearer invalid_token'} + ) + + assert response.status_code == 401 + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestInvalidPayload: + """Tests for error handling with invalid payloads.""" + + @pytest.mark.asyncio + async def test_missing_json_body(self, quart_test_client): + """ + POST endpoint without JSON body handles gracefully. + """ + # /api/v1/user/auth expects JSON with 'user' and 'password' + response = await quart_test_client.post('/api/v1/user/auth') + + # Should return error (500, 400, or 401) with stable JSON structure + assert response.status_code in (400, 500, 401) + data = await response.get_json() + + # Verify error response has expected structure + assert 'code' in data + assert 'msg' in data + + @pytest.mark.asyncio + async def test_invalid_json_structure(self, quart_test_client): + """ + POST with wrong JSON structure returns stable error. + """ + response = await quart_test_client.post( + '/api/v1/user/auth', + json={'wrong_field': 'value'} + ) + + # Should return error with stable JSON structure + assert response.status_code in (400, 500, 401) + data = await response.get_json() + assert 'code' in data + assert 'msg' in data + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestUserInitEndpoint: + """Tests for /api/v1/user/init endpoint.""" + + @pytest.mark.asyncio + async def test_user_init_get_returns_not_initialized(self, quart_test_client): + """ + GET /api/v1/user/init returns initialized status. + + Uses fake user_service.is_initialized() = False. + """ + response = await quart_test_client.get('/api/v1/user/init') + + assert response.status_code == 200 + data = await response.get_json() + + assert data['code'] == 0 + assert data['msg'] == 'ok' + assert data['data']['initialized'] is False + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestRealImports: + """Tests that verify real production code is imported.""" + + def test_http_controller_real_import(self): + """ + Verify HTTPController is real production class, not mock. + """ + from langbot.pkg.api.http.controller.main import HTTPController + + assert HTTPController.__name__ == 'HTTPController' + assert hasattr(HTTPController, 'initialize') + assert hasattr(HTTPController, 'register_routes') + + def test_group_real_import(self): + """ + Verify RouterGroup and AuthType are real production classes. + """ + from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups + + assert RouterGroup.__name__ == 'RouterGroup' + assert hasattr(AuthType, 'NONE') + assert hasattr(AuthType, 'USER_TOKEN') + assert isinstance(preregistered_groups, list) + + def test_system_group_registered(self): + """ + Verify SystemRouterGroup is registered in preregistered_groups. + """ + from langbot.pkg.api.http.controller.group import preregistered_groups + + # Find system group + system_group = None + for g in preregistered_groups: + if g.name == 'system': + system_group = g + break + + assert system_group is not None + assert system_group.path == '/api/v1/system' + + def test_user_group_registered(self): + """ + Verify UserRouterGroup is registered in preregistered_groups. + """ + from langbot.pkg.api.http.controller.group import preregistered_groups + + # Find user group + user_group = None + for g in preregistered_groups: + if g.name == 'user': + user_group = g + break + + assert user_group is not None + assert user_group.path == '/api/v1/user' \ No newline at end of file diff --git a/tests/integration/persistence/__init__.py b/tests/integration/persistence/__init__.py new file mode 100644 index 00000000..496ef868 --- /dev/null +++ b/tests/integration/persistence/__init__.py @@ -0,0 +1,5 @@ +""" +Persistence integration tests package. + +Tests for database migrations and storage behavior. +""" \ No newline at end of file diff --git a/tests/integration/persistence/test_migrations.py b/tests/integration/persistence/test_migrations.py new file mode 100644 index 00000000..ff8473a1 --- /dev/null +++ b/tests/integration/persistence/test_migrations.py @@ -0,0 +1,223 @@ +""" +SQLite migration integration tests. + +Tests real Alembic migration behavior using temporary SQLite databases. +Validates the migration workflow from .github/workflows/test-migrations.yml. + +Run: uv run pytest tests/integration/persistence/test_migrations.py -q +""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import create_async_engine + +from langbot.pkg.entity.persistence.base import Base +from langbot.pkg.persistence.alembic_runner import ( + run_alembic_upgrade, + run_alembic_stamp, + get_alembic_current, +) + + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def sqlite_db_url(tmp_path): + """Create SQLite URL with temporary database file.""" + db_file = tmp_path / "test_migrations.db" + return f"sqlite+aiosqlite:///{db_file}" + + +@pytest.fixture +async def sqlite_engine(sqlite_db_url): + """Create async SQLite engine.""" + engine = create_async_engine(sqlite_db_url) + yield engine + await engine.dispose() + + +class TestSQLiteMigrationBaseline: + """Tests for baseline stamp workflow.""" + + @pytest.mark.asyncio + async def test_baseline_stamp_sets_revision(self, sqlite_engine): + """ + Stamp baseline on existing tables sets correct revision. + + Workflow: + 1. Create tables via Base.metadata.create_all + 2. Stamp with '0001_baseline' + 3. Verify current revision is '0001_baseline' + """ + # Create all tables (simulates existing DB created by ORM) + async with sqlite_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Stamp baseline + await run_alembic_stamp(sqlite_engine, '0001_baseline') + + # Verify revision + rev = await get_alembic_current(sqlite_engine) + assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}" + + @pytest.mark.asyncio + async def test_baseline_stamp_on_empty_db(self, sqlite_engine): + """ + Stamp on empty database (no tables) still sets revision. + + This is an edge case - stamping without tables. + """ + # Don't create tables - stamp directly + await run_alembic_stamp(sqlite_engine, '0001_baseline') + + rev = await get_alembic_current(sqlite_engine) + assert rev == '0001_baseline' + + +class TestSQLiteMigrationUpgrade: + """Tests for upgrade to head workflow.""" + + @pytest.mark.asyncio + async def test_upgrade_from_baseline_to_head(self, sqlite_engine): + """ + Upgrade from baseline to head applies all migrations. + + Workflow: + 1. Create tables + 2. Stamp baseline + 3. Upgrade to head + 4. Verify current revision is head + """ + # Create tables + async with sqlite_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Stamp baseline + await run_alembic_stamp(sqlite_engine, '0001_baseline') + + # Upgrade to head + await run_alembic_upgrade(sqlite_engine, 'head') + + # Verify revision + rev = await get_alembic_current(sqlite_engine) + assert rev is not None, "Expected a revision after upgrade" + # Head should be the latest migration + assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}" + + @pytest.mark.asyncio + async def test_upgrade_idempotent(self, sqlite_engine): + """ + Running upgrade to head multiple times is idempotent. + + Workflow: + 1. Upgrade to head + 2. Get revision + 3. Upgrade to head again + 4. Verify same revision + """ + # Create tables + async with sqlite_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Stamp and upgrade + await run_alembic_stamp(sqlite_engine, '0001_baseline') + await run_alembic_upgrade(sqlite_engine, 'head') + + rev1 = await get_alembic_current(sqlite_engine) + + # Upgrade again - should be idempotent + await run_alembic_upgrade(sqlite_engine, 'head') + + rev2 = await get_alembic_current(sqlite_engine) + assert rev2 == rev1, f"Expected {rev1}, got {rev2}" + + +class TestSQLiteMigrationFreshDatabase: + """Tests for fresh database workflow.""" + + @pytest.mark.asyncio + async def test_fresh_db_upgrade_from_scratch(self, tmp_path): + """ + Fresh database (no tables) can be upgraded directly to head. + + Workflow: + 1. Create fresh engine with new DB file + 2. Create tables + 3. Upgrade to head + 4. Verify revision + """ + # Use different DB file for fresh test + fresh_db_file = tmp_path / "test_migrations_fresh.db" + fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}" + fresh_engine = create_async_engine(fresh_url) + + # Create tables on fresh DB + async with fresh_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Upgrade to head directly (no baseline stamp) + await run_alembic_upgrade(fresh_engine, 'head') + + # Verify revision + rev = await get_alembic_current(fresh_engine) + assert rev is not None, "Expected a revision on fresh DB" + + await fresh_engine.dispose() + + @pytest.mark.asyncio + async def test_fresh_db_without_create_all_fails_gracefully(self, tmp_path): + """ + Fresh database without create_all may fail or have empty tables. + + This tests the edge case where migrations run on truly empty DB. + The behavior depends on migration script implementation. + """ + fresh_db_file = tmp_path / "test_empty_migrations.db" + fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}" + fresh_engine = create_async_engine(fresh_url) + + # Don't create tables - try upgrade directly + # This may fail if migrations expect tables to exist + try: + await run_alembic_upgrade(fresh_engine, 'head') + rev = await get_alembic_current(fresh_engine) + # If it succeeds, verify revision + assert rev is not None + except Exception: + # If it fails, that's acceptable behavior + # Migrations may require create_all first + pass + + await fresh_engine.dispose() + + +class TestSQLiteMigrationGetCurrent: + """Tests for get_alembic_current behavior.""" + + @pytest.mark.asyncio + async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine): + """ + get_alembic_current returns None for unstamped database. + """ + # Create tables but don't stamp + async with sqlite_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # No stamp - should return None + rev = await get_alembic_current(sqlite_engine) + assert rev is None, f"Expected None for unstamped DB, got {rev}" + + @pytest.mark.asyncio + async def test_get_current_after_stamp_returns_revision(self, sqlite_engine): + """ + get_alembic_current returns correct revision after stamp. + """ + async with sqlite_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + await run_alembic_stamp(sqlite_engine, '0001_baseline') + + rev = await get_alembic_current(sqlite_engine) + assert rev == '0001_baseline' \ No newline at end of file diff --git a/tests/unit_tests/pipeline/test_chat_handler.py b/tests/unit_tests/pipeline/test_chat_handler.py index 791fd021..097ef2b4 100644 --- a/tests/unit_tests/pipeline/test_chat_handler.py +++ b/tests/unit_tests/pipeline/test_chat_handler.py @@ -1,428 +1,436 @@ """ -Unit tests for ChatMessageHandler behavior patterns. +Unit tests for ChatMessageHandler - REAL imports. -Tests cover chat processing patterns: -- Event emission for normal messages -- Provider invocation pattern -- Streaming response handling -- Error handling - -Uses pattern-based testing to avoid circular import issues. +Tests the actual ChatMessageHandler class from production code. +Uses tests.utils.import_isolation to break circular import chain safely. """ from __future__ import annotations import pytest -from unittest.mock import Mock, AsyncMock -import uuid +from unittest.mock import AsyncMock, Mock -from tests.factories import text_query +from tests.factories import FakeApp -class TestNormalMessageEventPattern: - """Tests for normal message event emission.""" +# ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== - def test_person_event_type(self): - """Person messages use PersonNormalMessageReceived.""" - import langbot_plugin.api.entities.events as events - from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes +@pytest.fixture(scope='module') +def mock_circular_import_chain(): + """ + Break circular import chain using isolated_sys_modules. - launcher_type = LauncherTypes.PERSON + Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr - event_class = ( - events.PersonNormalMessageReceived - if launcher_type == LauncherTypes.PERSON - else events.GroupNormalMessageReceived - ) + Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation. + """ + from tests.utils.import_isolation import ( + isolated_sys_modules, + make_pipeline_handler_import_mocks, + get_handler_modules_to_clear, + ) + from langbot_plugin.api.entities.builtin.provider.message import Message - assert event_class == events.PersonNormalMessageReceived + mocks = make_pipeline_handler_import_mocks() - def test_group_event_type(self): - """Group messages use GroupNormalMessageReceived.""" - import langbot_plugin.api.entities.events as events - from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + # Create a default runner that yields a simple response + class DefaultRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + yield Message(role='assistant', content='fake response') - launcher_type = LauncherTypes.GROUP + mocks['langbot.pkg.provider.runner'].preregistered_runners = [DefaultRunner] - event_class = ( - events.PersonNormalMessageReceived - if launcher_type == LauncherTypes.PERSON - else events.GroupNormalMessageReceived - ) + clear = get_handler_modules_to_clear('chat') - assert event_class == events.GroupNormalMessageReceived - - def test_event_fields_pattern(self): - """Normal message event has expected fields.""" - launcher_type = 'person' - launcher_id = '12345' - sender_id = '12345' - text_message = 'hello world' - - event_data = { - 'launcher_type': launcher_type, - 'launcher_id': launcher_id, - 'sender_id': sender_id, - 'text_message': text_message, - } - - assert event_data['text_message'] == 'hello world' + with isolated_sys_modules(mocks=mocks, clear=clear): + yield -class TestPreventDefaultHandling: - """Tests for prevent_default handling in chat.""" +@pytest.fixture +def fake_app(): + """Create FakeApp instance.""" + return FakeApp() + + +@pytest.fixture +def mock_event_ctx(): + """Create mock event context.""" + ctx = Mock() + ctx.is_prevented_default = Mock(return_value=False) + ctx.event = Mock() + ctx.event.user_message_alter = None + ctx.event.reply_message_chain = None + return ctx + + +@pytest.fixture +def set_runner(): + """Factory fixture to set a custom runner for tests.""" + def _set_runner(runner_class): + import sys + sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class] + return _set_runner + + +# ============== CACHED LAZY IMPORTS ============== + +_chat_handler_module = None +_entities_module = None + + +def get_chat_handler(): + """Import ChatMessageHandler after circular import chain is mocked.""" + global _chat_handler_module + if _chat_handler_module is None: + from importlib import import_module + _chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat') + return _chat_handler_module + + +def get_entities(): + """Import pipeline entities - uses real module.""" + global _entities_module + if _entities_module is None: + from importlib import import_module + _entities_module = import_module('langbot.pkg.pipeline.entities') + return _entities_module + + +# ============== REAL ChatMessageHandler Tests ============== + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestChatMessageHandlerReal: + """Tests for real ChatMessageHandler class.""" @pytest.mark.asyncio - async def test_prevent_default_interrupts(self): - """prevent_default without reply interrupts pipeline.""" + async def test_real_import_works(self): + """Verify we can import the real handler class.""" + chat = get_chat_handler() + assert hasattr(chat, 'ChatMessageHandler') + handler_cls = chat.ChatMessageHandler + assert handler_cls.__name__ == 'ChatMessageHandler' - # Simulate event context - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = True - event_ctx.event = Mock() - event_ctx.event.reply_message_chain = None + @pytest.mark.asyncio + async def test_handler_creation(self, fake_app): + """ChatMessageHandler can be instantiated.""" + chat = get_chat_handler() + handler = chat.ChatMessageHandler(fake_app) + assert handler.ap is fake_app + @pytest.mark.asyncio + async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx): + """prevent_default without reply chain yields INTERRUPT.""" + from tests.factories import text_query + + chat = get_chat_handler() + entities = get_entities() + + mock_event_ctx.is_prevented_default.return_value = True + mock_event_ctx.event.reply_message_chain = None + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + handler = chat.ChatMessageHandler(fake_app) query = text_query('hello') - query.resp_messages = [] - should_interrupt = False - if event_ctx.is_prevented_default(): - if event_ctx.event.reply_message_chain is None: - should_interrupt = True - - assert should_interrupt is True - - @pytest.mark.asyncio - async def test_prevent_default_with_reply_continues(self): - """prevent_default with reply continues with that reply.""" - from tests.factories.message import text_chain - - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = True - event_ctx.event = Mock() - event_ctx.event.reply_message_chain = text_chain('plugin reply') - - query = text_query('hello') - query.resp_messages = [] - - if event_ctx.is_prevented_default(): - if event_ctx.event.reply_message_chain is not None: - query.resp_messages.append(event_ctx.event.reply_message_chain) - - assert len(query.resp_messages) == 1 - - -class TestUserMessageAlteration: - """Tests for user_message alteration pattern.""" - - @pytest.mark.asyncio - async def test_string_alters_message(self): - """User message can be altered to string.""" - import langbot_plugin.api.entities.builtin.provider.message as provider_message - - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = False - event_ctx.event = Mock() - event_ctx.event.user_message_alter = 'altered text' - - query = text_query('original') - query.user_message = provider_message.Message(role='user', content=[]) - - # Pattern from handler - if event_ctx.event.user_message_alter is not None: - if isinstance(event_ctx.event.user_message_alter, str): - query.user_message.content = [ - provider_message.ContentElement.from_text(event_ctx.event.user_message_alter) - ] - - assert query.user_message.content[0].text == 'altered text' - - @pytest.mark.asyncio - async def test_list_alters_message(self): - """User message can be altered to list.""" - import langbot_plugin.api.entities.builtin.provider.message as provider_message - - altered_list = [ - provider_message.ContentElement.from_text('part1'), - provider_message.ContentElement.from_text('part2'), - ] - - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = False - event_ctx.event = Mock() - event_ctx.event.user_message_alter = altered_list - - query = text_query('original') - query.user_message = provider_message.Message(role='user', content=[]) - - if isinstance(event_ctx.event.user_message_alter, list): - query.user_message.content = event_ctx.event.user_message_alter - - assert len(query.user_message.content) == 2 - - -class TestRunnerSelection: - """Tests for runner selection pattern.""" - - def test_runner_by_name(self): - """Runner is selected by name from config.""" - runner_name = 'local-agent' - - # Simulate preregistered runners lookup - Mock with name attribute - r1 = Mock() - r1.name = 'local-agent' - r2 = Mock() - r2.name = 'dify' - r3 = Mock() - r3.name = 'n8n' - preregistered_runners = [r1, r2, r3] - - runner = None - for r in preregistered_runners: - if r.name == runner_name: - runner = r - break - - assert runner is not None - assert runner.name == 'local-agent' - - def test_unknown_runner_raises(self): - """Unknown runner name raises error.""" - runner_name = 'unknown-runner' - preregistered_runners = [ - Mock(name='local-agent'), - Mock(name='dify'), - ] - - runner = None - for r in preregistered_runners: - if r.name == runner_name: - runner = r - break - - if runner is None: - error_raised = True - - assert error_raised is True - - -class TestStreamingResponse: - """Tests for streaming response pattern.""" - - @pytest.mark.asyncio - async def test_streaming_chunks_pattern(self): - """Streaming produces multiple chunks.""" - chunks = ['Hello', ' World', '!'] results = [] + async for result in handler.handle(query): + results.append(result) - # Simulate streaming generator - async def stream_gen(): - for chunk in chunks: - results.append(chunk) - - await stream_gen() - - assert len(results) == 3 - assert ''.join(results) == 'Hello World!' + assert len(results) == 1 + assert results[0].result_type == entities.ResultType.INTERRUPT @pytest.mark.asyncio - async def test_streaming_resp_message_id(self): - """Streaming uses uuid for resp_message_id.""" - resp_message_id = str(uuid.uuid4()) + async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx): + """prevent_default with reply yields CONTINUE and updates resp_messages.""" + from tests.factories import text_query, text_chain - assert len(resp_message_id) == 36 # UUID format + chat = get_chat_handler() + entities = get_entities() - @pytest.mark.asyncio - async def test_streaming_pop_previous(self): - """Streaming pops previous response before adding new.""" - query = text_query('test') - query.resp_messages = [Mock()] # Previous chunk - query.resp_message_chain = [Mock()] + reply_chain = text_chain('plugin reply') + mock_event_ctx.is_prevented_default.return_value = True + mock_event_ctx.event.reply_message_chain = reply_chain + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) - # Pattern from handler: pop before adding new chunk - if query.resp_messages: - query.resp_messages.pop() - if query.resp_message_chain: - query.resp_message_chain.pop() - - query.resp_messages.append(Mock()) # New chunk - - assert len(query.resp_messages) == 1 # Only new chunk - - -class TestNonStreamingResponse: - """Tests for non-streaming response pattern.""" - - @pytest.mark.asyncio - async def test_single_response_pattern(self): - """Non-streaming produces single response.""" - query = text_query('test') + handler = chat.ChatMessageHandler(fake_app) + query = text_query('hello') query.resp_messages = [] - # Simulate non-streaming runner - async def run(): - yield Mock(readable_str=lambda: 'response text') - - async for result in run(): - query.resp_messages.append(result) + results = [] + async for result in handler.handle(query): + results.append(result) + assert len(results) == 1 + assert results[0].result_type == entities.ResultType.CONTINUE assert len(query.resp_messages) == 1 - - -class TestExceptionHandling: - """Tests for exception handling pattern.""" + assert query.resp_messages[0] == reply_chain @pytest.mark.asyncio - async def test_exception_interrupts(self): - """Exception produces INTERRUPT result.""" + async def test_user_message_alter_string(self, fake_app, mock_event_ctx, set_runner): + """user_message_alter as string updates query.user_message.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message - text_query('test') - pipeline_config = { - 'output': { - 'misc': { - 'exception-handling': 'show-hint', - 'failure-hint': 'Request failed.', - } - } - } + chat = get_chat_handler() - # Simulate exception - exception = ValueError('provider error') + mock_event_ctx.is_prevented_default.return_value = False + mock_event_ctx.event.user_message_alter = 'altered text' + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) - exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') + query = text_query('original') + query.adapter = Mock() + query.adapter.is_stream_output_supported = AsyncMock(return_value=False) + query.user_message = Message(role='user', content=[]) - if exception_handling == 'show-error': - user_notice = f'{exception}' - elif exception_handling == 'show-hint': - user_notice = pipeline_config['output']['misc'].get('failure-hint', 'Request failed.') - else: # hide - user_notice = None + class QuickRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + yield Message(role='assistant', content='ok') - assert user_notice == 'Request failed.' + set_runner(QuickRunner) + + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert query.user_message.content is not None @pytest.mark.asyncio - async def test_exception_show_error(self): - """show-error mode shows actual error.""" - text_query('test') - pipeline_config = { - 'output': { - 'misc': { - 'exception-handling': 'show-error', - } - } - } + async def test_adapter_without_stream_method_defaults_non_stream(self, fake_app, mock_event_ctx, set_runner): + """Adapter without is_stream_output_supported defaults to non-stream.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement - exception = ValueError('API timeout') + chat = get_chat_handler() - exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') + mock_event_ctx.is_prevented_default.return_value = False + mock_event_ctx.event.user_message_alter = None + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) - if exception_handling == 'show-error': - user_notice = f'{exception}' - else: - user_notice = 'Request failed.' - - assert user_notice == 'API timeout' - - @pytest.mark.asyncio - async def test_exception_hide(self): - """hide mode shows no user notice.""" - text_query('test') - pipeline_config = { - 'output': { - 'misc': { - 'exception-handling': 'hide', - } - } - } - - ValueError('hidden error') - - exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') - - if exception_handling == 'hide': - user_notice = None - else: - user_notice = 'Error' - - assert user_notice is None - - -class TestMessageHistoryUpdate: - """Tests for conversation message history.""" - - @pytest.mark.asyncio - async def test_messages_appended_to_conversation(self): - """User message and response appended to conversation.""" query = text_query('test') - query.session = Mock() - query.session.using_conversation = Mock() - query.session.using_conversation.messages = [] + query.adapter = Mock(spec=[]) + query.user_message = Message(role='user', content=[ContentElement.from_text('test')]) - query.user_message = Mock() - query.resp_messages = [Mock(), Mock()] + class SingleRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + yield Message(role='assistant', content='response') - # Pattern from handler after successful response - query.session.using_conversation.messages.append(query.user_message) - query.session.using_conversation.messages.extend(query.resp_messages) + set_runner(SingleRunner) - assert len(query.session.using_conversation.messages) == 3 + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(results) >= 1 -class TestStreamOutputCheck: - """Tests for stream output support check.""" +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestChatHandlerStreaming: + """Tests for streaming behavior.""" @pytest.mark.asyncio - async def test_adapter_stream_check(self): - """Adapter is checked for stream support.""" - adapter = AsyncMock() - adapter.is_stream_output_supported = AsyncMock(return_value=True) + async def test_streaming_chunks_collected(self, fake_app, mock_event_ctx, set_runner): + """Streaming produces multiple results.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement, MessageChunk - is_stream = await adapter.is_stream_output_supported() + chat = get_chat_handler() - assert is_stream is True + mock_event_ctx.is_prevented_default.return_value = False + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + query = text_query('stream test') + query.adapter = Mock() + query.adapter.is_stream_output_supported = AsyncMock(return_value=True) + query.adapter.create_message_card = AsyncMock() + query.user_message = Message(role='user', content=[ContentElement.from_text('test')]) + + class StreamRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + yield MessageChunk(role='assistant', content='Hello', is_final=False) + yield MessageChunk(role='assistant', content=' World', is_final=True) + + set_runner(StreamRunner) + + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(results) >= 1 + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestChatHandlerExceptions: + """Tests for exception handling.""" @pytest.mark.asyncio - async def test_adapter_no_stream_method(self): - """Adapter without method defaults to False.""" - adapter = Mock(spec=[]) # Empty spec, no methods - # No is_stream_output_supported method + async def test_runner_exception_yields_interrupt(self, fake_app, mock_event_ctx, set_runner): + """Runner exception yields INTERRUPT with error notices.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message - is_stream = False - try: - if hasattr(adapter, 'is_stream_output_supported'): - is_stream = await adapter.is_stream_output_supported() - except AttributeError: - is_stream = False + chat = get_chat_handler() + entities = get_entities() - assert is_stream is False + mock_event_ctx.is_prevented_default.return_value = False + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + query = text_query('fail test') + query.adapter = Mock() + query.adapter.is_stream_output_supported = AsyncMock(return_value=False) + query.user_message = Message(role='user', content=[]) -class TestTelemetryPattern: - """Tests for telemetry reporting pattern.""" - - def test_telemetry_payload_fields(self): - """Telemetry payload has expected fields.""" - query_id = 123 - adapter_name = 'TestAdapter' - runner_name = 'local-agent' - duration_ms = 150 - - payload = { - 'query_id': query_id, - 'adapter': adapter_name, - 'runner': runner_name, - 'duration_ms': duration_ms, + query.pipeline_config = { + 'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}}, + 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, } - assert payload['query_id'] == 123 - assert payload['duration_ms'] == 150 + class FailingRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + raise ValueError('API error') + yield - def test_telemetry_error_included(self): - """Telemetry includes error info on failure.""" - error_info = 'Traceback...' + set_runner(FailingRunner) - payload = { - 'error': error_info, + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(results) == 1 + assert results[0].result_type == entities.ResultType.INTERRUPT + assert results[0].user_notice == 'Request failed.' + assert results[0].error_notice is not None + + @pytest.mark.asyncio + async def test_exception_show_error_mode(self, fake_app, mock_event_ctx, set_runner): + """show-error mode shows actual exception.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message + + chat = get_chat_handler() + + mock_event_ctx.is_prevented_default.return_value = False + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + query = text_query('error test') + query.adapter = Mock() + query.adapter.is_stream_output_supported = AsyncMock(return_value=False) + query.user_message = Message(role='user', content=[]) + + query.pipeline_config = { + 'output': {'misc': {'exception-handling': 'show-error'}}, + 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, } - assert payload['error'] == 'Traceback...' \ No newline at end of file + class ErrorRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + raise ValueError('Custom error') + yield + + set_runner(ErrorRunner) + + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert results[0].user_notice == 'Custom error' + + @pytest.mark.asyncio + async def test_exception_hide_mode(self, fake_app, mock_event_ctx, set_runner): + """hide mode shows no user notice.""" + from tests.factories import text_query + from langbot_plugin.api.entities.builtin.provider.message import Message + + chat = get_chat_handler() + + mock_event_ctx.is_prevented_default.return_value = False + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + query = text_query('hide test') + query.adapter = Mock() + query.adapter.is_stream_output_supported = AsyncMock(return_value=False) + query.user_message = Message(role='user', content=[]) + + query.pipeline_config = { + 'output': {'misc': {'exception-handling': 'hide'}}, + 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, + } + + class HideErrorRunner: + name = 'local-agent' + def __init__(self, app, config): + self.app = app + self.config = config + async def run(self, query): + raise RuntimeError('hidden') + yield + + set_runner(HideErrorRunner) + + handler = chat.ChatMessageHandler(fake_app) + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert results[0].user_notice is None + + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestChatHandlerHelper: + """Tests for helper methods.""" + + def test_cut_str_short(self, fake_app): + """cut_str returns short string unchanged.""" + chat = get_chat_handler() + handler = chat.ChatMessageHandler(fake_app) + result = handler.cut_str('short text') + assert result == 'short text' + + def test_cut_str_long(self, fake_app): + """cut_str truncates long string.""" + chat = get_chat_handler() + handler = chat.ChatMessageHandler(fake_app) + result = handler.cut_str('this is a very long string that exceeds twenty characters') + assert '...' in result + assert len(result) <= 23 + + def test_cut_str_multiline(self, fake_app): + """cut_str truncates multiline string.""" + chat = get_chat_handler() + handler = chat.ChatMessageHandler(fake_app) + result = handler.cut_str('first line\nsecond line') + assert '...' in result \ No newline at end of file diff --git a/tests/unit_tests/pipeline/test_command_handler.py b/tests/unit_tests/pipeline/test_command_handler.py index 6c686b94..5006d248 100644 --- a/tests/unit_tests/pipeline/test_command_handler.py +++ b/tests/unit_tests/pipeline/test_command_handler.py @@ -1,308 +1,396 @@ """ -Unit tests for CommandHandler behavior patterns. +Unit tests for CommandHandler - REAL imports. -Tests cover command processing patterns: -- Command parsing and routing -- Event emission pattern -- Command manager interaction -- Privilege handling - -Uses pattern-based testing to avoid circular import issues in source code. +Tests the actual CommandHandler class from production code. +Uses tests.utils.import_isolation to break circular import chain safely. """ from __future__ import annotations import pytest -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock -from tests.factories import command_query +from tests.factories import FakeApp, command_query -class TestCommandParsingPattern: - """Tests for command parsing logic.""" +# ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== - def test_command_text_extraction(self): +@pytest.fixture(scope='module') +def mock_circular_import_chain(): + """ + Break circular import chain using isolated_sys_modules. + + Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr + + Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation. + """ + from tests.utils.import_isolation import ( + isolated_sys_modules, + make_pipeline_handler_import_mocks, + get_handler_modules_to_clear, + ) + + mocks = make_pipeline_handler_import_mocks() + clear = get_handler_modules_to_clear('command') + + with isolated_sys_modules(mocks=mocks, clear=clear): + yield + + +@pytest.fixture +def fake_app(): + """Create FakeApp instance.""" + return FakeApp() + + +@pytest.fixture +def mock_event_ctx(): + """Create mock event context.""" + ctx = Mock() + ctx.is_prevented_default = Mock(return_value=False) + ctx.event = Mock() + ctx.event.reply_message_chain = None + return ctx + + +@pytest.fixture +def mock_execute_factory(): + """Factory fixture to create mock cmd_mgr.execute generators.""" + def _create_execute( + text: str | None = 'ok', + error: str | None = None, + image_url: str | None = None, + image_base64: str | None = None, + file_url: str | None = None, + ): + async def mock_execute(command_text, full_command_text, query, session): + ret = Mock() + ret.text = text + ret.error = error + ret.image_url = image_url + ret.image_base64 = image_base64 + ret.file_url = file_url + yield ret + return mock_execute + return _create_execute + + +# ============== CACHED LAZY IMPORTS ============== + +_command_handler_module = None +_entities_module = None + + +def get_command_handler(): + """Import CommandHandler after circular import chain is mocked.""" + global _command_handler_module + if _command_handler_module is None: + from importlib import import_module + _command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command') + return _command_handler_module + + +def get_entities(): + """Import pipeline entities - uses real module.""" + global _entities_module + if _entities_module is None: + from importlib import import_module + _entities_module = import_module('langbot.pkg.pipeline.entities') + return _entities_module + + +# ============== REAL CommandHandler Tests ============== + +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestCommandHandlerReal: + """Tests for real CommandHandler class.""" + + @pytest.mark.asyncio + async def test_real_import_works(self): + """Verify we can import the real handler class.""" + command = get_command_handler() + assert hasattr(command, 'CommandHandler') + handler_cls = command.CommandHandler + assert handler_cls.__name__ == 'CommandHandler' + + @pytest.mark.asyncio + async def test_handler_creation(self, fake_app): + """CommandHandler can be instantiated.""" + command = get_command_handler() + handler = command.CommandHandler(fake_app) + assert handler.ap is fake_app + + @pytest.mark.asyncio + async def test_command_parsing_extracts_command_name(self, fake_app, mock_event_ctx): """Command text is extracted after prefix.""" - # Simulate the parsing pattern from command handler - full_command_text = "/help arg1 arg2" + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) - # Handler strips first character (prefix) - command_text = full_command_text.strip()[1:] - parts = command_text.split(' ') + executed_commands = [] + async def track_execute(command_text, full_command_text, query, session): + executed_commands.append(command_text) + ret = Mock() + ret.text = 'ok' + ret.error = None + ret.image_url = None + ret.image_base64 = None + ret.file_url = None + yield ret - assert parts[0] == 'help' - assert parts[1:] == ['arg1', 'arg2'] + fake_app.cmd_mgr.execute = track_execute - def test_empty_command_parts(self): - """Empty command has no parts.""" - full_command_text = "/" + handler = command.CommandHandler(fake_app) + query = command_query('help arg1 arg2') - command_text = full_command_text.strip()[1:] - parts = command_text.split(' ') + results = [] + async for result in handler.handle(query): + results.append(result) - assert parts == [''] + assert executed_commands[0] == 'help arg1 arg2' - def test_single_command_no_args(self): - """Single command has no arguments.""" - full_command_text = "/status" - - command_text = full_command_text.strip()[1:] - parts = command_text.split(' ') - - assert parts == ['status'] - - -class TestCommandEventCreation: - """Tests for command event creation pattern.""" - - def test_event_type_by_launcher_type(self): - """Event type differs for person/group.""" - import langbot_plugin.api.entities.events as events - - # Person command - person_event_class = events.PersonCommandSent - - # Group command - group_event_class = events.GroupCommandSent - - assert person_event_class is not None - assert group_event_class is not None - - def test_event_fields_pattern(self): - """Command event should have expected fields.""" + @pytest.mark.asyncio + async def test_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory): + """Admin users get privilege level 2.""" from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes - launcher_type = LauncherTypes.PERSON.value - launcher_id = '12345' - sender_id = '12345' - command = 'help' - params = ['arg1', 'arg2'] - is_admin = False + command = get_command_handler() - # Simulate event creation pattern - event_data = { - 'launcher_type': launcher_type, - 'launcher_id': launcher_id, - 'sender_id': sender_id, - 'command': command, - 'params': params, - 'is_admin': is_admin, - } + fake_app.instance_config.data = {'admins': ['person_12345']} + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory() - assert event_data['command'] == 'help' - assert event_data['params'] == ['arg1', 'arg2'] + handler = command.CommandHandler(fake_app) + query = command_query('status') + query.launcher_type = LauncherTypes.PERSON + query.launcher_id = 12345 + results = [] + async for result in handler.handle(query): + results.append(result) -class TestPrivilegeCheckPattern: - """Tests for privilege/admin check.""" - - def test_admin_check_by_session_id(self): - """Admin is checked by session_id format.""" - admins = ['person_12345', 'group_99999'] - launcher_type = 'person' - launcher_id = '12345' - - session_id = f'{launcher_type}_{launcher_id}' - is_admin = session_id in admins - - assert is_admin is True - - def test_non_admin_check(self): - """Non-admin user has privilege 1.""" - admins = ['person_12345'] - launcher_type = 'person' - launcher_id = '67890' - - session_id = f'{launcher_type}_{launcher_id}' - is_admin = session_id in admins - - assert is_admin is False - - def test_privilege_levels(self): - """Privilege level 1 for normal, 2 for admin.""" - normal_privilege = 1 - admin_privilege = 2 - - admins = ['person_12345'] - - # Normal user - session_id = 'person_67890' - privilege = 2 if session_id in admins else 1 - assert privilege == normal_privilege - - # Admin user - session_id = 'person_12345' - privilege = 2 if session_id in admins else 1 - assert privilege == admin_privilege - - -class TestCommandResultHandling: - """Tests for command result handling patterns.""" + call_args = fake_app.plugin_connector.emit_event.call_args + event = call_args[0][0] + assert event.is_admin is True @pytest.mark.asyncio - async def test_text_result_pattern(self): - """Text result is converted to message.""" - import langbot_plugin.api.entities.builtin.provider.message as provider_message + async def test_non_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory): + """Non-admin users get privilege level 1.""" + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes - # Simulate command return - ret = Mock() - ret.text = 'Command output' - ret.error = None - ret.image_url = None - ret.image_base64 = None - ret.file_url = None + command = get_command_handler() - # Pattern from handler: build content list - content = [] - if ret.text is not None: - content.append(provider_message.ContentElement.from_text(ret.text)) + fake_app.instance_config.data = {'admins': ['person_12345']} + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory() - assert len(content) == 1 - assert content[0].type == 'text' - assert content[0].text == 'Command output' + handler = command.CommandHandler(fake_app) + query = command_query('status') + query.launcher_type = LauncherTypes.PERSON + query.launcher_id = 67890 + + results = [] + async for result in handler.handle(query): + results.append(result) + + call_args = fake_app.plugin_connector.emit_event.call_args + event = call_args[0][0] + assert event.is_admin is False @pytest.mark.asyncio - async def test_error_result_pattern(self): + async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx): + """prevent_default with reply yields CONTINUE.""" + from tests.factories.message import text_chain + + command = get_command_handler() + entities = get_entities() + + reply_chain = text_chain('plugin reply') + mock_event_ctx.is_prevented_default.return_value = True + mock_event_ctx.event.reply_message_chain = reply_chain + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + handler = command.CommandHandler(fake_app) + query = command_query('test') + query.resp_messages = [] + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(results) == 1 + assert results[0].result_type == entities.ResultType.CONTINUE + assert len(query.resp_messages) == 1 + assert query.resp_messages[0] == reply_chain + + @pytest.mark.asyncio + async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx): + """prevent_default without reply yields INTERRUPT.""" + command = get_command_handler() + entities = get_entities() + + mock_event_ctx.is_prevented_default.return_value = True + mock_event_ctx.event.reply_message_chain = None + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + handler = command.CommandHandler(fake_app) + query = command_query('test') + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(results) == 1 + assert results[0].result_type == entities.ResultType.INTERRUPT + + @pytest.mark.asyncio + async def test_event_type_person_command(self, fake_app, mock_event_ctx, mock_execute_factory): + """Person launcher creates PersonCommandSent event.""" + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + from langbot_plugin.api.entities import events + + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory() + + handler = command.CommandHandler(fake_app) + query = command_query('help') + query.launcher_type = LauncherTypes.PERSON + + results = [] + async for result in handler.handle(query): + results.append(result) + + call_args = fake_app.plugin_connector.emit_event.call_args + event = call_args[0][0] + assert isinstance(event, events.PersonCommandSent) + + @pytest.mark.asyncio + async def test_event_type_group_command(self, fake_app, mock_event_ctx, mock_execute_factory): + """Group launcher creates GroupCommandSent event.""" + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + from langbot_plugin.api.entities import events + + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory() + + handler = command.CommandHandler(fake_app) + query = command_query('help') + query.launcher_type = LauncherTypes.GROUP + + results = [] + async for result in handler.handle(query): + results.append(result) + + call_args = fake_app.plugin_connector.emit_event.call_args + event = call_args[0][0] + assert isinstance(event, events.GroupCommandSent) + + @pytest.mark.asyncio + async def test_command_result_text(self, fake_app, mock_event_ctx, mock_execute_factory): + """Text result is added to resp_messages.""" + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory(text='Command output') + + handler = command.CommandHandler(fake_app) + query = command_query('echo') + query.resp_messages = [] + + results = [] + async for result in handler.handle(query): + results.append(result) + + assert len(query.resp_messages) == 1 + msg = query.resp_messages[0] + assert msg.role == 'command' + assert len(msg.content) == 1 + assert msg.content[0].type == 'text' + assert msg.content[0].text == 'Command output' + + @pytest.mark.asyncio + async def test_command_result_error(self, fake_app, mock_event_ctx, mock_execute_factory): """Error result creates error message.""" - import langbot_plugin.api.entities.builtin.provider.message as provider_message + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory(text=None, error='Command failed') - ret = Mock() - ret.text = None - ret.error = 'Command failed' + handler = command.CommandHandler(fake_app) + query = command_query('fail') + query.resp_messages = [] - # Error handling pattern - if ret.error is not None: - msg = provider_message.Message( - role='command', - content=str(ret.error), - ) + results = [] + async for result in handler.handle(query): + results.append(result) + assert len(query.resp_messages) == 1 + msg = query.resp_messages[0] assert msg.role == 'command' assert msg.content == 'Command failed' @pytest.mark.asyncio - async def test_image_result_pattern(self): - """Image result is added to content.""" - import langbot_plugin.api.entities.builtin.provider.message as provider_message + async def test_command_result_image_url(self, fake_app, mock_event_ctx, mock_execute_factory): + """Image URL result is added to content.""" + command = get_command_handler() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory( + text='Here is the image:', + image_url='https://example.com/image.png' + ) - ret = Mock() - ret.text = 'Here is the image:' - ret.error = None - ret.image_url = 'https://example.com/image.png' - ret.image_base64 = None - ret.file_url = None - - content = [] - if ret.text is not None: - content.append(provider_message.ContentElement.from_text(ret.text)) - if ret.image_url is not None: - content.append(provider_message.ContentElement.from_image_url(ret.image_url)) - - assert len(content) == 2 - assert content[0].type == 'text' - assert content[1].type == 'image_url' - - -class TestPreventDefaultHandling: - """Tests for prevent_default handling.""" - - @pytest.mark.asyncio - async def test_prevent_default_with_reply(self): - """prevent_default with reply continues pipeline.""" - from tests.factories.message import text_chain - - # Simulate event context - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = True - event_ctx.event = Mock() - event_ctx.event.reply_message_chain = text_chain('plugin reply') - - query = command_query('test') + handler = command.CommandHandler(fake_app) + query = command_query('image') query.resp_messages = [] - # Pattern from handler - if event_ctx.is_prevented_default(): - if event_ctx.event.reply_message_chain is not None: - query.resp_messages.append(event_ctx.event.reply_message_chain) - # yield CONTINUE - else: - # yield INTERRUPT - pass + results = [] + async for result in handler.handle(query): + results.append(result) - assert len(query.resp_messages) == 1 + msg = query.resp_messages[0] + assert len(msg.content) == 2 + assert msg.content[0].type == 'text' + assert msg.content[1].type == 'image_url' @pytest.mark.asyncio - async def test_prevent_default_without_reply(self): - """prevent_default without reply interrupts.""" - event_ctx = Mock() - event_ctx.is_prevented_default.return_value = True - event_ctx.event = Mock() - event_ctx.event.reply_message_chain = None + async def test_command_result_empty_interrupts(self, fake_app, mock_event_ctx, mock_execute_factory): + """Empty result yields INTERRUPT.""" + command = get_command_handler() + entities = get_entities() + fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + fake_app.cmd_mgr.execute = mock_execute_factory(text=None) - query = command_query('test') - query.resp_messages = [] + handler = command.CommandHandler(fake_app) + query = command_query('empty') - should_interrupt = False - if event_ctx.is_prevented_default(): - if event_ctx.event.reply_message_chain is None: - should_interrupt = True + results = [] + async for result in handler.handle(query): + results.append(result) - assert should_interrupt is True + assert results[0].result_type == entities.ResultType.INTERRUPT -class TestStringTruncationHelper: - """Tests for cut_str helper method.""" +@pytest.mark.usefixtures('mock_circular_import_chain') +class TestCommandHandlerHelper: + """Tests for helper methods.""" - def test_short_string_no_change(self): - """Short string is not truncated.""" - # Pattern from handler.cut_str - def cut_str(s: str) -> str: - s0 = s.split('\n')[0] - if len(s0) > 20 or '\n' in s: - s0 = s0[:20] + '...' - return s0 - - result = cut_str('short text') + def test_cut_str_short(self, fake_app): + """cut_str returns short string unchanged.""" + command = get_command_handler() + handler = command.CommandHandler(fake_app) + result = handler.cut_str('short text') assert result == 'short text' - def test_long_string_truncated(self): - """Long string is truncated.""" - def cut_str(s: str) -> str: - s0 = s.split('\n')[0] - if len(s0) > 20 or '\n' in s: - s0 = s0[:20] + '...' - return s0 - - result = cut_str('this is a very long string that exceeds twenty characters') + def test_cut_str_long(self, fake_app): + """cut_str truncates long string.""" + command = get_command_handler() + handler = command.CommandHandler(fake_app) + result = handler.cut_str('this is a very long string that exceeds twenty characters') assert '...' in result assert len(result) <= 23 - def test_multiline_truncated(self): - """Multiline string is truncated.""" - def cut_str(s: str) -> str: - s0 = s.split('\n')[0] - if len(s0) > 20 or '\n' in s: - s0 = s0[:20] + '...' - return s0 - - result = cut_str('first line\nsecond line\nthird') - assert '...' in result - - -class TestCommandPrefixConfiguration: - """Tests for command prefix configuration.""" - - def test_default_prefixes(self): - """Default prefixes are slash and exclamation.""" - default_prefixes = ['/', '!'] - assert '/' in default_prefixes - assert '!' in default_prefixes - - def test_custom_prefix(self): - """Custom prefix can be configured.""" - custom_prefix = '#' - full_text = f'{custom_prefix}help' - - # Would be checked against config['command']['prefix'] - is_command = full_text.startswith(custom_prefix) - assert is_command is True \ No newline at end of file + def test_cut_str_multiline(self, fake_app): + """cut_str truncates multiline string.""" + command = get_command_handler() + handler = command.CommandHandler(fake_app) + result = handler.cut_str('first line\nsecond line') + assert '...' in result \ No newline at end of file diff --git a/tests/utils/import_isolation.py b/tests/utils/import_isolation.py new file mode 100644 index 00000000..bcf78d56 --- /dev/null +++ b/tests/utils/import_isolation.py @@ -0,0 +1,162 @@ +""" +sys.modules isolation utilities for breaking circular import chains. + +Provides safe, reversible sys.modules manipulation for tests that need to +import modules with heavy import-time side effects (auto-registration, +circular dependencies, etc.). + +Usage pattern: + 1. Create mock objects for modules that cause circular imports + 2. Use isolated_sys_modules to temporarily patch sys.modules + 3. Import target module after patching + 4. Test the real production code + 5. Context manager automatically restores original sys.modules state + +Key principle: mock only what breaks the import chain, not what the code needs. +""" + +from __future__ import annotations + +import sys +import enum +from contextlib import contextmanager +from typing import Generator +from unittest.mock import MagicMock + + +class MockLifecycleControlScope(enum.Enum): + """Mock enum for breaking circular import in core.entities.""" + APPLICATION = 'application' + PLATFORM = 'platform' + PLUGIN = 'plugin' + PROVIDER = 'provider' + + +@contextmanager +def isolated_sys_modules( + mocks: dict[str, object], + clear: list[str] | None = None, +) -> Generator[None, None, None]: + """ + Context manager for isolated sys.modules manipulation. + + Safely patches sys.modules with mocks and clears specified modules, + then restores original state on exit. This prevents test pollution + where mocks leak into subsequent tests. + + Args: + mocks: Dict mapping module names to mock objects. + These will be set in sys.modules during the context. + clear: List of module names to remove from sys.modules before + entering the context. Useful for forcing re-import of + modules that depend on mocked modules. + + Example: + >>> with isolated_sys_modules( + ... mocks={'my_pkg.heavy_module': MagicMock()}, + ... clear=['my_pkg.target_module'], + ... ): + ... from my_pkg.target_module import MyClass # Safe import + + Note: + - Modules in both mocks and clear will be mocked (not cleared) + - Original state is restored even if exception occurs + - Modules not in sys.modules before context are removed after + """ + clear = clear or [] + touched = set(mocks.keys()) | set(clear) + + # Save original state for modules we'll touch + saved: dict[str, object] = {} + for name in touched: + if name in sys.modules: + saved[name] = sys.modules[name] + + try: + # Clear modules first (force re-import) + for name in clear: + if name not in mocks: # Don't clear if we're mocking it + sys.modules.pop(name, None) + + # Apply mocks + for name, module in mocks.items(): + sys.modules[name] = module + + yield + + finally: + # Restore original state - critical for test isolation + for name in touched: + if name in saved: + sys.modules[name] = saved[name] + else: + # Wasn't in sys.modules originally, remove it + sys.modules.pop(name, None) + + +def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]: + """ + Create mock objects needed to break circular import chain in handlers. + + The import chain: + handler → core.app → pipeline.controller → http_controller + → groups/plugins → taskmgr (partial init) + + This function creates minimal mocks that break this chain without + affecting the handler's ability to use real pipeline.entities + (needed for ResultType enum comparisons). + + Returns: + Dict mapping module names to MagicMock objects. + + Note: + These mocks are intentionally minimal - they only provide what's + needed to prevent circular imports. The actual handler code uses + real imports from langbot_plugin.api and langbot.pkg.pipeline.entities. + """ + # Mock core.entities with proper Enum class + mock_entities = MagicMock() + mock_entities.LifecycleControlScope = MockLifecycleControlScope + + # Mock core.app - Application class is referenced but not instantiated + mock_app = MagicMock() + + # Mock provider.runner - has preregistered_runners attribute + mock_runner = MagicMock() + mock_runner.preregistered_runners = [] # Empty by default, tests override + + # Mock utils.importutil - prevents auto-import of runners + mock_importutil = MagicMock() + mock_importutil.import_modules_in_pkg = lambda pkg: None + mock_importutil.import_modules_in_pkgs = lambda pkgs: None + + return { + 'langbot.pkg.core.entities': mock_entities, + 'langbot.pkg.core.app': mock_app, + 'langbot.pkg.pipeline.controller': MagicMock(), + 'langbot.pkg.pipeline.pipelinemgr': MagicMock(), + 'langbot.pkg.pipeline.process.process': MagicMock(), + 'langbot.pkg.provider.runner': mock_runner, + 'langbot.pkg.utils.importutil': mock_importutil, + } + + +def get_handler_modules_to_clear(handler_name: str) -> list[str]: + """ + Get list of handler-related modules to clear before import. + + These modules need to be cleared so they're re-imported after + the circular import chain is mocked. Without clearing, they'd + already be in sys.modules (possibly partially initialized). + + Args: + handler_name: The handler file name (e.g., 'chat', 'command') + + Returns: + List of module names to clear. + """ + return [ + 'langbot.pkg.pipeline.process.handler', + 'langbot.pkg.pipeline.process.handlers', + f'langbot.pkg.pipeline.process.handlers.{handler_name}', + ] \ No newline at end of file