refactor(test): consolidate FakeApp and add sys.modules isolation utility

- Extract tests/utils/import_isolation.py with isolated_sys_modules context manager
- Extend tests/factories/app.py FakeApp with handler-specific attributes
- Refactor test_chat_handler.py to use centralized FakeApp and cached imports
- Refactor test_command_handler.py with mock_execute_factory fixture
- Refactor test_smoke.py to move import-time sys.modules manipulation into fixture
- Add SQLite migration integration tests (G-002)
- Add HTTP API smoke integration tests (G-005)
- Update CI workflow to call pytest for SQLite migrations (G-004)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
huanghuoguoguo
2026-05-08 22:41:48 +08:00
parent 3780a68dfa
commit 59871c3118
11 changed files with 1540 additions and 658 deletions

View File

@@ -9,11 +9,13 @@ on:
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
jobs:
test-migrations-sqlite:
@@ -34,53 +36,13 @@ jobs:
- name: Install dependencies
run: uv sync --dev
- name: Test Alembic upgrade (SQLite)
run: |
uv run python -c "
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
async def main():
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
# Create all tables (simulates existing DB)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(engine, '0001_baseline')
rev = await get_alembic_current(engine)
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
print(f'Stamped: {rev}')
# Upgrade to head
await run_alembic_upgrade(engine, 'head')
rev = await get_alembic_current(engine)
print(f'After upgrade: {rev}')
assert rev is not None, 'Expected a revision after upgrade'
# Verify idempotent
await run_alembic_upgrade(engine, 'head')
rev2 = await get_alembic_current(engine)
assert rev2 == rev, f'Expected {rev}, got {rev2}'
print(f'Idempotent check passed: {rev2}')
# Fresh DB: upgrade from scratch
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
async with engine2.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_upgrade(engine2, 'head')
rev3 = await get_alembic_current(engine2)
print(f'Fresh DB upgrade: {rev3}')
assert rev3 is not None
print('All SQLite migration tests passed!')
asyncio.run(main())
"
- name: Run SQLite migration tests
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
# TODO(G-003): Migrate PostgreSQL tests to pytest integration tests
# PostgreSQL requires external database service, which will be handled in G-003.
# The inline script below will be replaced with:
# uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
test-migrations-postgres:
name: Migrations (PostgreSQL)
runs-on: ubuntu-latest
@@ -168,4 +130,4 @@ jobs:
print('All PostgreSQL migration tests passed!')
asyncio.run(main())
"
"

View File

@@ -17,6 +17,14 @@ tests/
│ ├── message.py # Message/query factories
│ ├── provider.py # FakeProvider factory
│ └── platform.py # FakePlatform factory
├── integration/ # Integration tests (real resources)
│ ├── __init__.py
│ ├── api/ # HTTP API tests
│ │ ├── __init__.py
│ │ └── test_smoke.py # API smoke tests
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
@@ -28,6 +36,9 @@ tests/
│ ├── plugin/ # Plugin system tests
│ ├── provider/ # Provider tests
│ └── storage/ # Storage tests
├── utils/ # Test utilities
│ ├── __init__.py
│ └── import_isolation.py # sys.modules isolation for circular imports
└── README.md # This file
```
@@ -147,13 +158,51 @@ uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_
# Run only unit tests
uv run pytest tests/unit_tests/ -m unit
# Run only integration tests (when available)
uv run pytest tests/ -m integration
# Run only integration tests
uv run pytest tests/integration/ -m integration
# Run integration tests excluding slow ones
uv run pytest tests/integration/ -m "not slow" -q
# Skip slow tests
uv run pytest tests/unit_tests/ -m "not slow"
```
### Running integration tests
Integration tests validate real system behavior with actual database/network resources.
```bash
# Run all integration tests (excluding slow ones)
uv run pytest tests/integration/ -m "not slow" -q
# Run SQLite migration integration tests
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
# Run API smoke integration tests
uv run pytest tests/integration/api/test_smoke.py -q
# Run with verbose output
uv run pytest tests/integration/ -v
```
Note: Integration tests use:
- Temporary databases (tmp_path) for persistence tests
- Fake app/services for API tests (no real provider/platform)
- Do not require external services
### Running migration tests locally
SQLite migration tests can be run locally without any external dependencies:
```bash
# SQLite migration tests (uses tmp_path, no external DB needed)
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
```
CI workflow `.github/workflows/test-migrations.yml` runs SQLite tests using pytest.
PostgreSQL migration tests still use inline Python script (will be migrated to pytest in G-003).
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
@@ -250,7 +299,10 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
## Future Enhancements
- [x] Add integration tests for database migrations (SQLite)
- [ ] Add PostgreSQL migration integration tests (G-003)
- [ ] Add integration tests for full pipeline execution
- [x] Add API smoke integration tests
- [ ] Add E2E tests
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality

View File

@@ -18,6 +18,8 @@ class FakeApp:
command_prefix: list[str] = ["/", "!"],
command_enable: bool = True,
pipeline_concurrency: int = 10,
admins: list[str] | None = None,
**extra_attrs,
):
self.logger = self._create_mock_logger()
self.sess_mgr = self._create_mock_session_manager()
@@ -30,9 +32,19 @@ class FakeApp:
command_prefix=command_prefix,
command_enable=command_enable,
pipeline_concurrency=pipeline_concurrency,
admins=admins or [],
)
self.task_mgr = self._create_mock_task_manager()
# Handler-specific optional attributes
self.telemetry = self._create_mock_telemetry()
self.survey = None
self.cmd_mgr = self._create_mock_cmd_mgr()
# Apply any extra attributes for specific test scenarios
for name, value in extra_attrs.items():
setattr(self, name, value)
# Captured outbound messages (for assertions)
self._outbound_messages: list = []
@@ -82,11 +94,13 @@ class FakeApp:
command_prefix: list[str],
command_enable: bool,
pipeline_concurrency: int,
admins: list[str],
):
instance_config = Mock()
instance_config.data = {
"command": {"prefix": command_prefix, "enable": command_enable},
"concurrency": {"pipeline": pipeline_concurrency},
"admins": admins,
}
return instance_config
@@ -95,6 +109,16 @@ class FakeApp:
task_mgr.create_task = Mock()
return task_mgr
def _create_mock_telemetry(self):
telemetry = AsyncMock()
telemetry.start_send_task = AsyncMock()
return telemetry
def _create_mock_cmd_mgr(self):
cmd_mgr = AsyncMock()
cmd_mgr.execute = AsyncMock()
return cmd_mgr
def capture_message(self, message):
"""Capture an outbound message for test assertions."""
self._outbound_messages.append(message)

View File

@@ -0,0 +1,6 @@
"""
Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""

View File

@@ -0,0 +1,5 @@
"""
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""

View File

@@ -0,0 +1,347 @@
"""
API smoke integration tests.
Tests real HTTP API behavior using Quart test client.
Validates controller/service/routing wiring without real provider/platform.
Run: uv run pytest tests/integration/api/test_smoke.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain for API controller using isolated_sys_modules.
Chain: http_controller → groups/plugins → core.app → pipeline entities
We need to mock core.app to prevent the circular chain when importing HTTPController.
But we must allow groups to be imported to populate preregistered_groups.
"""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
# Mock core.app with minimal Application that groups can reference
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
# Mock core.entities with proper Enum
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
# Modules to clear (force re-import after mocking)
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.system',
'langbot.pkg.api.http.controller.groups.user',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
# Import groups after mocking core.app/core.entities
import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401
yield
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
Create minimal FakeApp for API smoke tests with all required services.
Uses tests.factories.FakeApp as base and adds API-specific services.
"""
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# API-specific services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=False)
app.user_service.authenticate = AsyncMock(return_value='fake_token')
app.user_service.create_user = AsyncMock()
app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token'))
app.user_service.get_user_by_email = AsyncMock(return_value=Mock())
app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token')
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
app.maintenance_service = Mock()
app.maintenance_service.get_storage_analysis = AsyncMock(return_value={})
app.plugin_connector.is_enable_plugin = False
app.plugin_connector.ping_plugin_runtime = AsyncMock()
app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []})
app.task_mgr.get_task_by_id = Mock(return_value=None)
# Required by controller groups
app.model_mgr = Mock()
app.platform_mgr = Mock()
app.pipeline_pool = Mock()
app.pipeline_mgr = Mock()
return app
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app):
"""
Create Quart test client with real HTTPController and route registration.
Requires mock_circular_import_chain fixture to run first (usefixtures).
"""
from langbot.pkg.api.http.controller.main import HTTPController
controller = HTTPController(fake_api_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@pytest.mark.asyncio
async def test_healthz_returns_ok(self, quart_test_client):
"""
/healthz endpoint returns {'code': 0, 'msg': 'ok'}.
This tests:
- HTTPController instantiation
- Quart app creation
- Route registration
- Basic response handling
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
data = await response.get_json()
assert data == {'code': 0, 'msg': 'ok'}
@pytest.mark.asyncio
async def test_healthz_no_auth_required(self, quart_test_client):
"""
/healthz doesn't require authentication.
Tests that AuthType.NONE endpoints work without headers.
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestSystemEndpoint:
"""Tests for /api/v1/system endpoints."""
@pytest.mark.asyncio
async def test_system_info_no_auth(self, quart_test_client):
"""
/api/v1/system/info returns system information without auth.
AuthType.NONE endpoint.
"""
response = await quart_test_client.get('/api/v1/system/info')
assert response.status_code == 200
data = await response.get_json()
# Verify response structure
assert data['code'] == 0
assert data['msg'] == 'ok'
assert 'data' in data
# Verify expected fields
system_data = data['data']
assert 'version' in system_data
assert 'debug' in system_data
assert 'edition' in system_data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProtectedEndpoints:
"""Tests for authentication/authorization behavior."""
@pytest.mark.asyncio
async def test_protected_endpoint_rejects_no_token(self, quart_test_client):
"""
Protected endpoint (USER_TOKEN) returns 401 without auth.
Tests that AuthType.USER_TOKEN properly rejects unauthorized requests.
"""
# /api/v1/user/check-token requires USER_TOKEN
response = await quart_test_client.get('/api/v1/user/check-token')
assert response.status_code == 401
data = await response.get_json()
# Verify error response structure
assert data['code'] == -1
assert 'msg' in data
@pytest.mark.asyncio
async def test_protected_endpoint_with_invalid_token(self, quart_test_client):
"""
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestInvalidPayload:
"""Tests for error handling with invalid payloads."""
@pytest.mark.asyncio
async def test_missing_json_body(self, quart_test_client):
"""
POST endpoint without JSON body handles gracefully.
"""
# /api/v1/user/auth expects JSON with 'user' and 'password'
response = await quart_test_client.post('/api/v1/user/auth')
# Should return error (500, 400, or 401) with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
# Verify error response has expected structure
assert 'code' in data
assert 'msg' in data
@pytest.mark.asyncio
async def test_invalid_json_structure(self, quart_test_client):
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
assert 'code' in data
assert 'msg' in data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestUserInitEndpoint:
"""Tests for /api/v1/user/init endpoint."""
@pytest.mark.asyncio
async def test_user_init_get_returns_not_initialized(self, quart_test_client):
"""
GET /api/v1/user/init returns initialized status.
Uses fake user_service.is_initialized() = False.
"""
response = await quart_test_client.get('/api/v1/user/init')
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert data['msg'] == 'ok'
assert data['data']['initialized'] is False
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRealImports:
"""Tests that verify real production code is imported."""
def test_http_controller_real_import(self):
"""
Verify HTTPController is real production class, not mock.
"""
from langbot.pkg.api.http.controller.main import HTTPController
assert HTTPController.__name__ == 'HTTPController'
assert hasattr(HTTPController, 'initialize')
assert hasattr(HTTPController, 'register_routes')
def test_group_real_import(self):
"""
Verify RouterGroup and AuthType are real production classes.
"""
from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups
assert RouterGroup.__name__ == 'RouterGroup'
assert hasattr(AuthType, 'NONE')
assert hasattr(AuthType, 'USER_TOKEN')
assert isinstance(preregistered_groups, list)
def test_system_group_registered(self):
"""
Verify SystemRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find system group
system_group = None
for g in preregistered_groups:
if g.name == 'system':
system_group = g
break
assert system_group is not None
assert system_group.path == '/api/v1/system'
def test_user_group_registered(self):
"""
Verify UserRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find user group
user_group = None
for g in preregistered_groups:
if g.name == 'user':
user_group = g
break
assert user_group is not None
assert user_group.path == '/api/v1/user'

View File

@@ -0,0 +1,5 @@
"""
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""

View File

@@ -0,0 +1,223 @@
"""
SQLite migration integration tests.
Tests real Alembic migration behavior using temporary SQLite databases.
Validates the migration workflow from .github/workflows/test-migrations.yml.
Run: uv run pytest tests/integration/persistence/test_migrations.py -q
"""
from __future__ import annotations
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import (
run_alembic_upgrade,
run_alembic_stamp,
get_alembic_current,
)
pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
@pytest.fixture
async def sqlite_engine(sqlite_db_url):
"""Create async SQLite engine."""
engine = create_async_engine(sqlite_db_url)
yield engine
await engine.dispose()
class TestSQLiteMigrationBaseline:
"""Tests for baseline stamp workflow."""
@pytest.mark.asyncio
async def test_baseline_stamp_sets_revision(self, sqlite_engine):
"""
Stamp baseline on existing tables sets correct revision.
Workflow:
1. Create tables via Base.metadata.create_all
2. Stamp with '0001_baseline'
3. Verify current revision is '0001_baseline'
"""
# Create all tables (simulates existing DB created by ORM)
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_baseline_stamp_on_empty_db(self, sqlite_engine):
"""
Stamp on empty database (no tables) still sets revision.
This is an edge case - stamping without tables.
"""
# Don't create tables - stamp directly
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
class TestSQLiteMigrationUpgrade:
"""Tests for upgrade to head workflow."""
@pytest.mark.asyncio
async def test_upgrade_from_baseline_to_head(self, sqlite_engine):
"""
Upgrade from baseline to head applies all migrations.
Workflow:
1. Create tables
2. Stamp baseline
3. Upgrade to head
4. Verify current revision is head
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Upgrade to head
await run_alembic_upgrade(sqlite_engine, 'head')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
"""
Running upgrade to head multiple times is idempotent.
Workflow:
1. Upgrade to head
2. Get revision
3. Upgrade to head again
4. Verify same revision
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp and upgrade
await run_alembic_stamp(sqlite_engine, '0001_baseline')
await run_alembic_upgrade(sqlite_engine, 'head')
rev1 = await get_alembic_current(sqlite_engine)
# Upgrade again - should be idempotent
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestSQLiteMigrationFreshDatabase:
"""Tests for fresh database workflow."""
@pytest.mark.asyncio
async def test_fresh_db_upgrade_from_scratch(self, tmp_path):
"""
Fresh database (no tables) can be upgraded directly to head.
Workflow:
1. Create fresh engine with new DB file
2. Create tables
3. Upgrade to head
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
async with fresh_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Upgrade to head directly (no baseline stamp)
await run_alembic_upgrade(fresh_engine, 'head')
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
await fresh_engine.dispose()
@pytest.mark.asyncio
async def test_fresh_db_without_create_all_fails_gracefully(self, tmp_path):
"""
Fresh database without create_all may fail or have empty tables.
This tests the edge case where migrations run on truly empty DB.
The behavior depends on migration script implementation.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Don't create tables - try upgrade directly
# This may fail if migrations expect tables to exist
try:
await run_alembic_upgrade(fresh_engine, 'head')
rev = await get_alembic_current(fresh_engine)
# If it succeeds, verify revision
assert rev is not None
except Exception:
# If it fails, that's acceptable behavior
# Migrations may require create_all first
pass
await fresh_engine.dispose()
class TestSQLiteMigrationGetCurrent:
"""Tests for get_alembic_current behavior."""
@pytest.mark.asyncio
async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine):
"""
get_alembic_current returns None for unstamped database.
"""
# Create tables but don't stamp
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
"""
get_alembic_current returns correct revision after stamp.
"""
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'

View File

@@ -1,428 +1,436 @@
"""
Unit tests for ChatMessageHandler behavior patterns.
Unit tests for ChatMessageHandler - REAL imports.
Tests cover chat processing patterns:
- Event emission for normal messages
- Provider invocation pattern
- Streaming response handling
- Error handling
Uses pattern-based testing to avoid circular import issues.
Tests the actual ChatMessageHandler class from production code.
Uses tests.utils.import_isolation to break circular import chain safely.
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
import uuid
from unittest.mock import AsyncMock, Mock
from tests.factories import text_query
from tests.factories import FakeApp
class TestNormalMessageEventPattern:
"""Tests for normal message event emission."""
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
def test_person_event_type(self):
"""Person messages use PersonNormalMessageReceived."""
import langbot_plugin.api.entities.events as events
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain using isolated_sys_modules.
launcher_type = LauncherTypes.PERSON
Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr
event_class = (
events.PersonNormalMessageReceived
if launcher_type == LauncherTypes.PERSON
else events.GroupNormalMessageReceived
)
Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation.
"""
from tests.utils.import_isolation import (
isolated_sys_modules,
make_pipeline_handler_import_mocks,
get_handler_modules_to_clear,
)
from langbot_plugin.api.entities.builtin.provider.message import Message
assert event_class == events.PersonNormalMessageReceived
mocks = make_pipeline_handler_import_mocks()
def test_group_event_type(self):
"""Group messages use GroupNormalMessageReceived."""
import langbot_plugin.api.entities.events as events
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
launcher_type = LauncherTypes.GROUP
mocks['langbot.pkg.provider.runner'].preregistered_runners = [DefaultRunner]
event_class = (
events.PersonNormalMessageReceived
if launcher_type == LauncherTypes.PERSON
else events.GroupNormalMessageReceived
)
clear = get_handler_modules_to_clear('chat')
assert event_class == events.GroupNormalMessageReceived
def test_event_fields_pattern(self):
"""Normal message event has expected fields."""
launcher_type = 'person'
launcher_id = '12345'
sender_id = '12345'
text_message = 'hello world'
event_data = {
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'text_message': text_message,
}
assert event_data['text_message'] == 'hello world'
with isolated_sys_modules(mocks=mocks, clear=clear):
yield
class TestPreventDefaultHandling:
"""Tests for prevent_default handling in chat."""
@pytest.fixture
def fake_app():
"""Create FakeApp instance."""
return FakeApp()
@pytest.fixture
def mock_event_ctx():
"""Create mock event context."""
ctx = Mock()
ctx.is_prevented_default = Mock(return_value=False)
ctx.event = Mock()
ctx.event.user_message_alter = None
ctx.event.reply_message_chain = None
return ctx
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
# ============== CACHED LAZY IMPORTS ==============
_chat_handler_module = None
_entities_module = None
def get_chat_handler():
"""Import ChatMessageHandler after circular import chain is mocked."""
global _chat_handler_module
if _chat_handler_module is None:
from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module
def get_entities():
"""Import pipeline entities - uses real module."""
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class."""
@pytest.mark.asyncio
async def test_prevent_default_interrupts(self):
"""prevent_default without reply interrupts pipeline."""
async def test_real_import_works(self):
"""Verify we can import the real handler class."""
chat = get_chat_handler()
assert hasattr(chat, 'ChatMessageHandler')
handler_cls = chat.ChatMessageHandler
assert handler_cls.__name__ == 'ChatMessageHandler'
# Simulate event context
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = True
event_ctx.event = Mock()
event_ctx.event.reply_message_chain = None
@pytest.mark.asyncio
async def test_handler_creation(self, fake_app):
"""ChatMessageHandler can be instantiated."""
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
assert handler.ap is fake_app
@pytest.mark.asyncio
async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx):
"""prevent_default without reply chain yields INTERRUPT."""
from tests.factories import text_query
chat = get_chat_handler()
entities = get_entities()
mock_event_ctx.is_prevented_default.return_value = True
mock_event_ctx.event.reply_message_chain = None
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
handler = chat.ChatMessageHandler(fake_app)
query = text_query('hello')
query.resp_messages = []
should_interrupt = False
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is None:
should_interrupt = True
assert should_interrupt is True
@pytest.mark.asyncio
async def test_prevent_default_with_reply_continues(self):
"""prevent_default with reply continues with that reply."""
from tests.factories.message import text_chain
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = True
event_ctx.event = Mock()
event_ctx.event.reply_message_chain = text_chain('plugin reply')
query = text_query('hello')
query.resp_messages = []
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
query.resp_messages.append(event_ctx.event.reply_message_chain)
assert len(query.resp_messages) == 1
class TestUserMessageAlteration:
"""Tests for user_message alteration pattern."""
@pytest.mark.asyncio
async def test_string_alters_message(self):
"""User message can be altered to string."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = False
event_ctx.event = Mock()
event_ctx.event.user_message_alter = 'altered text'
query = text_query('original')
query.user_message = provider_message.Message(role='user', content=[])
# Pattern from handler
if event_ctx.event.user_message_alter is not None:
if isinstance(event_ctx.event.user_message_alter, str):
query.user_message.content = [
provider_message.ContentElement.from_text(event_ctx.event.user_message_alter)
]
assert query.user_message.content[0].text == 'altered text'
@pytest.mark.asyncio
async def test_list_alters_message(self):
"""User message can be altered to list."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
altered_list = [
provider_message.ContentElement.from_text('part1'),
provider_message.ContentElement.from_text('part2'),
]
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = False
event_ctx.event = Mock()
event_ctx.event.user_message_alter = altered_list
query = text_query('original')
query.user_message = provider_message.Message(role='user', content=[])
if isinstance(event_ctx.event.user_message_alter, list):
query.user_message.content = event_ctx.event.user_message_alter
assert len(query.user_message.content) == 2
class TestRunnerSelection:
"""Tests for runner selection pattern."""
def test_runner_by_name(self):
"""Runner is selected by name from config."""
runner_name = 'local-agent'
# Simulate preregistered runners lookup - Mock with name attribute
r1 = Mock()
r1.name = 'local-agent'
r2 = Mock()
r2.name = 'dify'
r3 = Mock()
r3.name = 'n8n'
preregistered_runners = [r1, r2, r3]
runner = None
for r in preregistered_runners:
if r.name == runner_name:
runner = r
break
assert runner is not None
assert runner.name == 'local-agent'
def test_unknown_runner_raises(self):
"""Unknown runner name raises error."""
runner_name = 'unknown-runner'
preregistered_runners = [
Mock(name='local-agent'),
Mock(name='dify'),
]
runner = None
for r in preregistered_runners:
if r.name == runner_name:
runner = r
break
if runner is None:
error_raised = True
assert error_raised is True
class TestStreamingResponse:
"""Tests for streaming response pattern."""
@pytest.mark.asyncio
async def test_streaming_chunks_pattern(self):
"""Streaming produces multiple chunks."""
chunks = ['Hello', ' World', '!']
results = []
async for result in handler.handle(query):
results.append(result)
# Simulate streaming generator
async def stream_gen():
for chunk in chunks:
results.append(chunk)
await stream_gen()
assert len(results) == 3
assert ''.join(results) == 'Hello World!'
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_streaming_resp_message_id(self):
"""Streaming uses uuid for resp_message_id."""
resp_message_id = str(uuid.uuid4())
async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx):
"""prevent_default with reply yields CONTINUE and updates resp_messages."""
from tests.factories import text_query, text_chain
assert len(resp_message_id) == 36 # UUID format
chat = get_chat_handler()
entities = get_entities()
@pytest.mark.asyncio
async def test_streaming_pop_previous(self):
"""Streaming pops previous response before adding new."""
query = text_query('test')
query.resp_messages = [Mock()] # Previous chunk
query.resp_message_chain = [Mock()]
reply_chain = text_chain('plugin reply')
mock_event_ctx.is_prevented_default.return_value = True
mock_event_ctx.event.reply_message_chain = reply_chain
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Pattern from handler: pop before adding new chunk
if query.resp_messages:
query.resp_messages.pop()
if query.resp_message_chain:
query.resp_message_chain.pop()
query.resp_messages.append(Mock()) # New chunk
assert len(query.resp_messages) == 1 # Only new chunk
class TestNonStreamingResponse:
"""Tests for non-streaming response pattern."""
@pytest.mark.asyncio
async def test_single_response_pattern(self):
"""Non-streaming produces single response."""
query = text_query('test')
handler = chat.ChatMessageHandler(fake_app)
query = text_query('hello')
query.resp_messages = []
# Simulate non-streaming runner
async def run():
yield Mock(readable_str=lambda: 'response text')
async for result in run():
query.resp_messages.append(result)
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) == 1
assert results[0].result_type == entities.ResultType.CONTINUE
assert len(query.resp_messages) == 1
class TestExceptionHandling:
"""Tests for exception handling pattern."""
assert query.resp_messages[0] == reply_chain
@pytest.mark.asyncio
async def test_exception_interrupts(self):
"""Exception produces INTERRUPT result."""
async def test_user_message_alter_string(self, fake_app, mock_event_ctx, set_runner):
"""user_message_alter as string updates query.user_message."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message
text_query('test')
pipeline_config = {
'output': {
'misc': {
'exception-handling': 'show-hint',
'failure-hint': 'Request failed.',
}
}
}
chat = get_chat_handler()
# Simulate exception
exception = ValueError('provider error')
mock_event_ctx.is_prevented_default.return_value = False
mock_event_ctx.event.user_message_alter = 'altered text'
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
query = text_query('original')
query.adapter = Mock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[])
if exception_handling == 'show-error':
user_notice = f'{exception}'
elif exception_handling == 'show-hint':
user_notice = pipeline_config['output']['misc'].get('failure-hint', 'Request failed.')
else: # hide
user_notice = None
class QuickRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='ok')
assert user_notice == 'Request failed.'
set_runner(QuickRunner)
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert query.user_message.content is not None
@pytest.mark.asyncio
async def test_exception_show_error(self):
"""show-error mode shows actual error."""
text_query('test')
pipeline_config = {
'output': {
'misc': {
'exception-handling': 'show-error',
}
}
}
async def test_adapter_without_stream_method_defaults_non_stream(self, fake_app, mock_event_ctx, set_runner):
"""Adapter without is_stream_output_supported defaults to non-stream."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement
exception = ValueError('API timeout')
chat = get_chat_handler()
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
mock_event_ctx.is_prevented_default.return_value = False
mock_event_ctx.event.user_message_alter = None
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
if exception_handling == 'show-error':
user_notice = f'{exception}'
else:
user_notice = 'Request failed.'
assert user_notice == 'API timeout'
@pytest.mark.asyncio
async def test_exception_hide(self):
"""hide mode shows no user notice."""
text_query('test')
pipeline_config = {
'output': {
'misc': {
'exception-handling': 'hide',
}
}
}
ValueError('hidden error')
exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint')
if exception_handling == 'hide':
user_notice = None
else:
user_notice = 'Error'
assert user_notice is None
class TestMessageHistoryUpdate:
"""Tests for conversation message history."""
@pytest.mark.asyncio
async def test_messages_appended_to_conversation(self):
"""User message and response appended to conversation."""
query = text_query('test')
query.session = Mock()
query.session.using_conversation = Mock()
query.session.using_conversation.messages = []
query.adapter = Mock(spec=[])
query.user_message = Message(role='user', content=[ContentElement.from_text('test')])
query.user_message = Mock()
query.resp_messages = [Mock(), Mock()]
class SingleRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='response')
# Pattern from handler after successful response
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
set_runner(SingleRunner)
assert len(query.session.using_conversation.messages) == 3
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) >= 1
class TestStreamOutputCheck:
"""Tests for stream output support check."""
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatHandlerStreaming:
"""Tests for streaming behavior."""
@pytest.mark.asyncio
async def test_adapter_stream_check(self):
"""Adapter is checked for stream support."""
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=True)
async def test_streaming_chunks_collected(self, fake_app, mock_event_ctx, set_runner):
"""Streaming produces multiple results."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message, ContentElement, MessageChunk
is_stream = await adapter.is_stream_output_supported()
chat = get_chat_handler()
assert is_stream is True
mock_event_ctx.is_prevented_default.return_value = False
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
query = text_query('stream test')
query.adapter = Mock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=True)
query.adapter.create_message_card = AsyncMock()
query.user_message = Message(role='user', content=[ContentElement.from_text('test')])
class StreamRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True)
set_runner(StreamRunner)
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) >= 1
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatHandlerExceptions:
"""Tests for exception handling."""
@pytest.mark.asyncio
async def test_adapter_no_stream_method(self):
"""Adapter without method defaults to False."""
adapter = Mock(spec=[]) # Empty spec, no methods
# No is_stream_output_supported method
async def test_runner_exception_yields_interrupt(self, fake_app, mock_event_ctx, set_runner):
"""Runner exception yields INTERRUPT with error notices."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message
is_stream = False
try:
if hasattr(adapter, 'is_stream_output_supported'):
is_stream = await adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
chat = get_chat_handler()
entities = get_entities()
assert is_stream is False
mock_event_ctx.is_prevented_default.return_value = False
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
query = text_query('fail test')
query.adapter = Mock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[])
class TestTelemetryPattern:
"""Tests for telemetry reporting pattern."""
def test_telemetry_payload_fields(self):
"""Telemetry payload has expected fields."""
query_id = 123
adapter_name = 'TestAdapter'
runner_name = 'local-agent'
duration_ms = 150
payload = {
'query_id': query_id,
'adapter': adapter_name,
'runner': runner_name,
'duration_ms': duration_ms,
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
}
assert payload['query_id'] == 123
assert payload['duration_ms'] == 150
class FailingRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('API error')
yield
def test_telemetry_error_included(self):
"""Telemetry includes error info on failure."""
error_info = 'Traceback...'
set_runner(FailingRunner)
payload = {
'error': error_info,
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert results[0].user_notice == 'Request failed.'
assert results[0].error_notice is not None
@pytest.mark.asyncio
async def test_exception_show_error_mode(self, fake_app, mock_event_ctx, set_runner):
"""show-error mode shows actual exception."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message
chat = get_chat_handler()
mock_event_ctx.is_prevented_default.return_value = False
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
query = text_query('error test')
query.adapter = Mock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[])
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
}
assert payload['error'] == 'Traceback...'
class ErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('Custom error')
yield
set_runner(ErrorRunner)
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert results[0].user_notice == 'Custom error'
@pytest.mark.asyncio
async def test_exception_hide_mode(self, fake_app, mock_event_ctx, set_runner):
"""hide mode shows no user notice."""
from tests.factories import text_query
from langbot_plugin.api.entities.builtin.provider.message import Message
chat = get_chat_handler()
mock_event_ctx.is_prevented_default.return_value = False
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
query = text_query('hide test')
query.adapter = Mock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=False)
query.user_message = Message(role='user', content=[])
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
}
class HideErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise RuntimeError('hidden')
yield
set_runner(HideErrorRunner)
handler = chat.ChatMessageHandler(fake_app)
results = []
async for result in handler.handle(query):
results.append(result)
assert results[0].user_notice is None
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatHandlerHelper:
"""Tests for helper methods."""
def test_cut_str_short(self, fake_app):
"""cut_str returns short string unchanged."""
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('short text')
assert result == 'short text'
def test_cut_str_long(self, fake_app):
"""cut_str truncates long string."""
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('this is a very long string that exceeds twenty characters')
assert '...' in result
assert len(result) <= 23
def test_cut_str_multiline(self, fake_app):
"""cut_str truncates multiline string."""
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result

View File

@@ -1,308 +1,396 @@
"""
Unit tests for CommandHandler behavior patterns.
Unit tests for CommandHandler - REAL imports.
Tests cover command processing patterns:
- Command parsing and routing
- Event emission pattern
- Command manager interaction
- Privilege handling
Uses pattern-based testing to avoid circular import issues in source code.
Tests the actual CommandHandler class from production code.
Uses tests.utils.import_isolation to break circular import chain safely.
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from tests.factories import command_query
from tests.factories import FakeApp, command_query
class TestCommandParsingPattern:
"""Tests for command parsing logic."""
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
def test_command_text_extraction(self):
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain using isolated_sys_modules.
Chain: handler → core.app → pipeline.controller → http_controller → groups/plugins → taskmgr
Uses tests.utils.import_isolation for safe, reversible sys.modules manipulation.
"""
from tests.utils.import_isolation import (
isolated_sys_modules,
make_pipeline_handler_import_mocks,
get_handler_modules_to_clear,
)
mocks = make_pipeline_handler_import_mocks()
clear = get_handler_modules_to_clear('command')
with isolated_sys_modules(mocks=mocks, clear=clear):
yield
@pytest.fixture
def fake_app():
"""Create FakeApp instance."""
return FakeApp()
@pytest.fixture
def mock_event_ctx():
"""Create mock event context."""
ctx = Mock()
ctx.is_prevented_default = Mock(return_value=False)
ctx.event = Mock()
ctx.event.reply_message_chain = None
return ctx
@pytest.fixture
def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute(
text: str | None = 'ok',
error: str | None = None,
image_url: str | None = None,
image_base64: str | None = None,
file_url: str | None = None,
):
async def mock_execute(command_text, full_command_text, query, session):
ret = Mock()
ret.text = text
ret.error = error
ret.image_url = image_url
ret.image_base64 = image_base64
ret.file_url = file_url
yield ret
return mock_execute
return _create_execute
# ============== CACHED LAZY IMPORTS ==============
_command_handler_module = None
_entities_module = None
def get_command_handler():
"""Import CommandHandler after circular import chain is mocked."""
global _command_handler_module
if _command_handler_module is None:
from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module
def get_entities():
"""Import pipeline entities - uses real module."""
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal:
"""Tests for real CommandHandler class."""
@pytest.mark.asyncio
async def test_real_import_works(self):
"""Verify we can import the real handler class."""
command = get_command_handler()
assert hasattr(command, 'CommandHandler')
handler_cls = command.CommandHandler
assert handler_cls.__name__ == 'CommandHandler'
@pytest.mark.asyncio
async def test_handler_creation(self, fake_app):
"""CommandHandler can be instantiated."""
command = get_command_handler()
handler = command.CommandHandler(fake_app)
assert handler.ap is fake_app
@pytest.mark.asyncio
async def test_command_parsing_extracts_command_name(self, fake_app, mock_event_ctx):
"""Command text is extracted after prefix."""
# Simulate the parsing pattern from command handler
full_command_text = "/help arg1 arg2"
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Handler strips first character (prefix)
command_text = full_command_text.strip()[1:]
parts = command_text.split(' ')
executed_commands = []
async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text)
ret = Mock()
ret.text = 'ok'
ret.error = None
ret.image_url = None
ret.image_base64 = None
ret.file_url = None
yield ret
assert parts[0] == 'help'
assert parts[1:] == ['arg1', 'arg2']
fake_app.cmd_mgr.execute = track_execute
def test_empty_command_parts(self):
"""Empty command has no parts."""
full_command_text = "/"
handler = command.CommandHandler(fake_app)
query = command_query('help arg1 arg2')
command_text = full_command_text.strip()[1:]
parts = command_text.split(' ')
results = []
async for result in handler.handle(query):
results.append(result)
assert parts == ['']
assert executed_commands[0] == 'help arg1 arg2'
def test_single_command_no_args(self):
"""Single command has no arguments."""
full_command_text = "/status"
command_text = full_command_text.strip()[1:]
parts = command_text.split(' ')
assert parts == ['status']
class TestCommandEventCreation:
"""Tests for command event creation pattern."""
def test_event_type_by_launcher_type(self):
"""Event type differs for person/group."""
import langbot_plugin.api.entities.events as events
# Person command
person_event_class = events.PersonCommandSent
# Group command
group_event_class = events.GroupCommandSent
assert person_event_class is not None
assert group_event_class is not None
def test_event_fields_pattern(self):
"""Command event should have expected fields."""
@pytest.mark.asyncio
async def test_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Admin users get privilege level 2."""
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
launcher_type = LauncherTypes.PERSON.value
launcher_id = '12345'
sender_id = '12345'
command = 'help'
params = ['arg1', 'arg2']
is_admin = False
command = get_command_handler()
# Simulate event creation pattern
event_data = {
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'command': command,
'params': params,
'is_admin': is_admin,
}
fake_app.instance_config.data = {'admins': ['person_12345']}
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory()
assert event_data['command'] == 'help'
assert event_data['params'] == ['arg1', 'arg2']
handler = command.CommandHandler(fake_app)
query = command_query('status')
query.launcher_type = LauncherTypes.PERSON
query.launcher_id = 12345
results = []
async for result in handler.handle(query):
results.append(result)
class TestPrivilegeCheckPattern:
"""Tests for privilege/admin check."""
def test_admin_check_by_session_id(self):
"""Admin is checked by session_id format."""
admins = ['person_12345', 'group_99999']
launcher_type = 'person'
launcher_id = '12345'
session_id = f'{launcher_type}_{launcher_id}'
is_admin = session_id in admins
assert is_admin is True
def test_non_admin_check(self):
"""Non-admin user has privilege 1."""
admins = ['person_12345']
launcher_type = 'person'
launcher_id = '67890'
session_id = f'{launcher_type}_{launcher_id}'
is_admin = session_id in admins
assert is_admin is False
def test_privilege_levels(self):
"""Privilege level 1 for normal, 2 for admin."""
normal_privilege = 1
admin_privilege = 2
admins = ['person_12345']
# Normal user
session_id = 'person_67890'
privilege = 2 if session_id in admins else 1
assert privilege == normal_privilege
# Admin user
session_id = 'person_12345'
privilege = 2 if session_id in admins else 1
assert privilege == admin_privilege
class TestCommandResultHandling:
"""Tests for command result handling patterns."""
call_args = fake_app.plugin_connector.emit_event.call_args
event = call_args[0][0]
assert event.is_admin is True
@pytest.mark.asyncio
async def test_text_result_pattern(self):
"""Text result is converted to message."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
async def test_non_admin_privilege_check(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Non-admin users get privilege level 1."""
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
# Simulate command return
ret = Mock()
ret.text = 'Command output'
ret.error = None
ret.image_url = None
ret.image_base64 = None
ret.file_url = None
command = get_command_handler()
# Pattern from handler: build content list
content = []
if ret.text is not None:
content.append(provider_message.ContentElement.from_text(ret.text))
fake_app.instance_config.data = {'admins': ['person_12345']}
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory()
assert len(content) == 1
assert content[0].type == 'text'
assert content[0].text == 'Command output'
handler = command.CommandHandler(fake_app)
query = command_query('status')
query.launcher_type = LauncherTypes.PERSON
query.launcher_id = 67890
results = []
async for result in handler.handle(query):
results.append(result)
call_args = fake_app.plugin_connector.emit_event.call_args
event = call_args[0][0]
assert event.is_admin is False
@pytest.mark.asyncio
async def test_error_result_pattern(self):
async def test_prevent_default_with_reply_continues(self, fake_app, mock_event_ctx):
"""prevent_default with reply yields CONTINUE."""
from tests.factories.message import text_chain
command = get_command_handler()
entities = get_entities()
reply_chain = text_chain('plugin reply')
mock_event_ctx.is_prevented_default.return_value = True
mock_event_ctx.event.reply_message_chain = reply_chain
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
handler = command.CommandHandler(fake_app)
query = command_query('test')
query.resp_messages = []
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) == 1
assert results[0].result_type == entities.ResultType.CONTINUE
assert len(query.resp_messages) == 1
assert query.resp_messages[0] == reply_chain
@pytest.mark.asyncio
async def test_prevent_default_without_reply_interrupts(self, fake_app, mock_event_ctx):
"""prevent_default without reply yields INTERRUPT."""
command = get_command_handler()
entities = get_entities()
mock_event_ctx.is_prevented_default.return_value = True
mock_event_ctx.event.reply_message_chain = None
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
handler = command.CommandHandler(fake_app)
query = command_query('test')
results = []
async for result in handler.handle(query):
results.append(result)
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_event_type_person_command(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Person launcher creates PersonCommandSent event."""
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
from langbot_plugin.api.entities import events
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory()
handler = command.CommandHandler(fake_app)
query = command_query('help')
query.launcher_type = LauncherTypes.PERSON
results = []
async for result in handler.handle(query):
results.append(result)
call_args = fake_app.plugin_connector.emit_event.call_args
event = call_args[0][0]
assert isinstance(event, events.PersonCommandSent)
@pytest.mark.asyncio
async def test_event_type_group_command(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Group launcher creates GroupCommandSent event."""
from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes
from langbot_plugin.api.entities import events
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory()
handler = command.CommandHandler(fake_app)
query = command_query('help')
query.launcher_type = LauncherTypes.GROUP
results = []
async for result in handler.handle(query):
results.append(result)
call_args = fake_app.plugin_connector.emit_event.call_args
event = call_args[0][0]
assert isinstance(event, events.GroupCommandSent)
@pytest.mark.asyncio
async def test_command_result_text(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Text result is added to resp_messages."""
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(text='Command output')
handler = command.CommandHandler(fake_app)
query = command_query('echo')
query.resp_messages = []
results = []
async for result in handler.handle(query):
results.append(result)
assert len(query.resp_messages) == 1
msg = query.resp_messages[0]
assert msg.role == 'command'
assert len(msg.content) == 1
assert msg.content[0].type == 'text'
assert msg.content[0].text == 'Command output'
@pytest.mark.asyncio
async def test_command_result_error(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Error result creates error message."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(text=None, error='Command failed')
ret = Mock()
ret.text = None
ret.error = 'Command failed'
handler = command.CommandHandler(fake_app)
query = command_query('fail')
query.resp_messages = []
# Error handling pattern
if ret.error is not None:
msg = provider_message.Message(
role='command',
content=str(ret.error),
)
results = []
async for result in handler.handle(query):
results.append(result)
assert len(query.resp_messages) == 1
msg = query.resp_messages[0]
assert msg.role == 'command'
assert msg.content == 'Command failed'
@pytest.mark.asyncio
async def test_image_result_pattern(self):
"""Image result is added to content."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
async def test_command_result_image_url(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Image URL result is added to content."""
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:',
image_url='https://example.com/image.png'
)
ret = Mock()
ret.text = 'Here is the image:'
ret.error = None
ret.image_url = 'https://example.com/image.png'
ret.image_base64 = None
ret.file_url = None
content = []
if ret.text is not None:
content.append(provider_message.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
assert len(content) == 2
assert content[0].type == 'text'
assert content[1].type == 'image_url'
class TestPreventDefaultHandling:
"""Tests for prevent_default handling."""
@pytest.mark.asyncio
async def test_prevent_default_with_reply(self):
"""prevent_default with reply continues pipeline."""
from tests.factories.message import text_chain
# Simulate event context
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = True
event_ctx.event = Mock()
event_ctx.event.reply_message_chain = text_chain('plugin reply')
query = command_query('test')
handler = command.CommandHandler(fake_app)
query = command_query('image')
query.resp_messages = []
# Pattern from handler
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
query.resp_messages.append(event_ctx.event.reply_message_chain)
# yield CONTINUE
else:
# yield INTERRUPT
pass
results = []
async for result in handler.handle(query):
results.append(result)
assert len(query.resp_messages) == 1
msg = query.resp_messages[0]
assert len(msg.content) == 2
assert msg.content[0].type == 'text'
assert msg.content[1].type == 'image_url'
@pytest.mark.asyncio
async def test_prevent_default_without_reply(self):
"""prevent_default without reply interrupts."""
event_ctx = Mock()
event_ctx.is_prevented_default.return_value = True
event_ctx.event = Mock()
event_ctx.event.reply_message_chain = None
async def test_command_result_empty_interrupts(self, fake_app, mock_event_ctx, mock_execute_factory):
"""Empty result yields INTERRUPT."""
command = get_command_handler()
entities = get_entities()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(text=None)
query = command_query('test')
query.resp_messages = []
handler = command.CommandHandler(fake_app)
query = command_query('empty')
should_interrupt = False
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is None:
should_interrupt = True
results = []
async for result in handler.handle(query):
results.append(result)
assert should_interrupt is True
assert results[0].result_type == entities.ResultType.INTERRUPT
class TestStringTruncationHelper:
"""Tests for cut_str helper method."""
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerHelper:
"""Tests for helper methods."""
def test_short_string_no_change(self):
"""Short string is not truncated."""
# Pattern from handler.cut_str
def cut_str(s: str) -> str:
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0
result = cut_str('short text')
def test_cut_str_short(self, fake_app):
"""cut_str returns short string unchanged."""
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('short text')
assert result == 'short text'
def test_long_string_truncated(self):
"""Long string is truncated."""
def cut_str(s: str) -> str:
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0
result = cut_str('this is a very long string that exceeds twenty characters')
def test_cut_str_long(self, fake_app):
"""cut_str truncates long string."""
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('this is a very long string that exceeds twenty characters')
assert '...' in result
assert len(result) <= 23
def test_multiline_truncated(self):
"""Multiline string is truncated."""
def cut_str(s: str) -> str:
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0
result = cut_str('first line\nsecond line\nthird')
assert '...' in result
class TestCommandPrefixConfiguration:
"""Tests for command prefix configuration."""
def test_default_prefixes(self):
"""Default prefixes are slash and exclamation."""
default_prefixes = ['/', '!']
assert '/' in default_prefixes
assert '!' in default_prefixes
def test_custom_prefix(self):
"""Custom prefix can be configured."""
custom_prefix = '#'
full_text = f'{custom_prefix}help'
# Would be checked against config['command']['prefix']
is_command = full_text.startswith(custom_prefix)
assert is_command is True
def test_cut_str_multiline(self, fake_app):
"""cut_str truncates multiline string."""
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result

View File

@@ -0,0 +1,162 @@
"""
sys.modules isolation utilities for breaking circular import chains.
Provides safe, reversible sys.modules manipulation for tests that need to
import modules with heavy import-time side effects (auto-registration,
circular dependencies, etc.).
Usage pattern:
1. Create mock objects for modules that cause circular imports
2. Use isolated_sys_modules to temporarily patch sys.modules
3. Import target module after patching
4. Test the real production code
5. Context manager automatically restores original sys.modules state
Key principle: mock only what breaks the import chain, not what the code needs.
"""
from __future__ import annotations
import sys
import enum
from contextlib import contextmanager
from typing import Generator
from unittest.mock import MagicMock
class MockLifecycleControlScope(enum.Enum):
"""Mock enum for breaking circular import in core.entities."""
APPLICATION = 'application'
PLATFORM = 'platform'
PLUGIN = 'plugin'
PROVIDER = 'provider'
@contextmanager
def isolated_sys_modules(
mocks: dict[str, object],
clear: list[str] | None = None,
) -> Generator[None, None, None]:
"""
Context manager for isolated sys.modules manipulation.
Safely patches sys.modules with mocks and clears specified modules,
then restores original state on exit. This prevents test pollution
where mocks leak into subsequent tests.
Args:
mocks: Dict mapping module names to mock objects.
These will be set in sys.modules during the context.
clear: List of module names to remove from sys.modules before
entering the context. Useful for forcing re-import of
modules that depend on mocked modules.
Example:
>>> with isolated_sys_modules(
... mocks={'my_pkg.heavy_module': MagicMock()},
... clear=['my_pkg.target_module'],
... ):
... from my_pkg.target_module import MyClass # Safe import
Note:
- Modules in both mocks and clear will be mocked (not cleared)
- Original state is restored even if exception occurs
- Modules not in sys.modules before context are removed after
"""
clear = clear or []
touched = set(mocks.keys()) | set(clear)
# Save original state for modules we'll touch
saved: dict[str, object] = {}
for name in touched:
if name in sys.modules:
saved[name] = sys.modules[name]
try:
# Clear modules first (force re-import)
for name in clear:
if name not in mocks: # Don't clear if we're mocking it
sys.modules.pop(name, None)
# Apply mocks
for name, module in mocks.items():
sys.modules[name] = module
yield
finally:
# Restore original state - critical for test isolation
for name in touched:
if name in saved:
sys.modules[name] = saved[name]
else:
# Wasn't in sys.modules originally, remove it
sys.modules.pop(name, None)
def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
"""
Create mock objects needed to break circular import chain in handlers.
The import chain:
handler → core.app → pipeline.controller → http_controller
→ groups/plugins → taskmgr (partial init)
This function creates minimal mocks that break this chain without
affecting the handler's ability to use real pipeline.entities
(needed for ResultType enum comparisons).
Returns:
Dict mapping module names to MagicMock objects.
Note:
These mocks are intentionally minimal - they only provide what's
needed to prevent circular imports. The actual handler code uses
real imports from langbot_plugin.api and langbot.pkg.pipeline.entities.
"""
# Mock core.entities with proper Enum class
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
# Mock core.app - Application class is referenced but not instantiated
mock_app = MagicMock()
# Mock provider.runner - has preregistered_runners attribute
mock_runner = MagicMock()
mock_runner.preregistered_runners = [] # Empty by default, tests override
# Mock utils.importutil - prevents auto-import of runners
mock_importutil = MagicMock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
return {
'langbot.pkg.core.entities': mock_entities,
'langbot.pkg.core.app': mock_app,
'langbot.pkg.pipeline.controller': MagicMock(),
'langbot.pkg.pipeline.pipelinemgr': MagicMock(),
'langbot.pkg.pipeline.process.process': MagicMock(),
'langbot.pkg.provider.runner': mock_runner,
'langbot.pkg.utils.importutil': mock_importutil,
}
def get_handler_modules_to_clear(handler_name: str) -> list[str]:
"""
Get list of handler-related modules to clear before import.
These modules need to be cleared so they're re-imported after
the circular import chain is mocked. Without clearing, they'd
already be in sys.modules (possibly partially initialized).
Args:
handler_name: The handler file name (e.g., 'chat', 'command')
Returns:
List of module names to clear.
"""
return [
'langbot.pkg.pipeline.process.handler',
'langbot.pkg.pipeline.process.handlers',
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
]