mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-25 23:14:20 +00:00
test(phase2): add unit tests for core, persistence, plugin, utils
- Add test_handler_helpers.py for plugin handler helpers (7 tests) - Add test_mgr_methods.py for persistence manager (5 tests) - Add test_app_config_validation.py for core app config (12 tests) - Add test_knowledge_service.py for API knowledge service (22 tests) - Add test_kbmgr.py for RAG knowledge base manager (39 tests) - Add test_survey_manager.py for survey manager (22 tests) - Add test_connector_methods.py for plugin connector (24 tests) - Add test_funcschema.py for utils function schema (9 tests) - Add test_platform.py for utils platform detection (7 tests) - Add test_extract_deps.py for plugin deps extraction (7 tests) - Add test_database_decorator.py for persistence decorator (7 tests) - Add test_load_config.py for core config loading (19 tests) - Add COVERAGE_EXCLUSIONS.md documenting external adapter exclusions - Fix test_chat_session_limit.py path for portability Coverage: core 28% → 30%, persistence 24% → 24.4%, plugin 27% → 28% Total: 1082 tests passed, core module coverage 45.5% Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,192 @@
|
||||
"""Unit tests for core app config validation methods.
|
||||
|
||||
Tests cover:
|
||||
- _get_positive_int_config() validation
|
||||
- _get_positive_float_config() validation
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def get_app_module():
|
||||
"""Lazy import to avoid circular import issues."""
|
||||
return import_module('langbot.pkg.core.app')
|
||||
|
||||
|
||||
class TestGetPositiveIntConfig:
|
||||
"""Tests for _get_positive_int_config method."""
|
||||
|
||||
def test_returns_value_when_valid_positive_int(self):
|
||||
"""Test returns parsed int for valid positive value."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config(10, default=30, name='test.config')
|
||||
|
||||
assert result == 10
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_returns_value_when_valid_string_int(self):
|
||||
"""Test returns parsed int for string value."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config('50', default=30, name='test.config')
|
||||
|
||||
assert result == 50
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_returns_default_for_zero(self):
|
||||
"""Test returns default when value is zero."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config(0, default=30, name='test.config')
|
||||
|
||||
assert result == 30
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_returns_default_for_negative(self):
|
||||
"""Test returns default when value is negative."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config(-5, default=30, name='test.config')
|
||||
|
||||
assert result == 30
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_returns_default_for_invalid_string(self):
|
||||
"""Test returns default when value is invalid string."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config('invalid', default=30, name='test.config')
|
||||
|
||||
assert result == 30
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_returns_default_for_none(self):
|
||||
"""Test returns default when value is None."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_int_config(None, default=30, name='test.config')
|
||||
|
||||
assert result == 30
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
|
||||
class TestGetPositiveFloatConfig:
|
||||
"""Tests for _get_positive_float_config method."""
|
||||
|
||||
def test_returns_value_when_valid_positive_float(self):
|
||||
"""Test returns parsed float for valid positive value."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config(1.5, default=2.0, name='test.config')
|
||||
|
||||
assert result == 1.5
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_returns_value_when_valid_int(self):
|
||||
"""Test returns float for valid int value."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config(2, default=1.0, name='test.config')
|
||||
|
||||
assert result == 2.0
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_returns_value_when_valid_string_float(self):
|
||||
"""Test returns parsed float for string value."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config('0.5', default=1.0, name='test.config')
|
||||
|
||||
assert result == 0.5
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_returns_default_for_zero(self):
|
||||
"""Test returns default when value is zero."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config(0.0, default=1.0, name='test.config')
|
||||
|
||||
assert result == 1.0
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_returns_default_for_negative(self):
|
||||
"""Test returns default when value is negative."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config(-1.0, default=2.0, name='test.config')
|
||||
|
||||
assert result == 2.0
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_returns_default_for_invalid_string(self):
|
||||
"""Test returns default when value is invalid string."""
|
||||
app_module = get_app_module()
|
||||
|
||||
mock_logger = Mock()
|
||||
|
||||
app = app_module.Application()
|
||||
app.logger = mock_logger
|
||||
|
||||
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
|
||||
|
||||
assert result == 1.5
|
||||
mock_logger.warning.assert_called_once()
|
||||
@@ -0,0 +1,266 @@
|
||||
"""Unit tests for core stages load_config _apply_env_overrides_to_config.
|
||||
|
||||
Tests cover:
|
||||
- Environment variable parsing and path conversion
|
||||
- Type conversion (bool, int, float, string)
|
||||
- List handling (comma-separated)
|
||||
- Dict type skipping
|
||||
- Missing key creation
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def get_load_config_module():
|
||||
"""Lazy import to avoid circular import issues."""
|
||||
return import_module('langbot.pkg.core.stages.load_config')
|
||||
|
||||
|
||||
class TestApplyEnvOverridesToConfig:
|
||||
"""Tests for _apply_env_overrides_to_config function."""
|
||||
|
||||
def test_override_string_value(self):
|
||||
"""Test overriding an existing string config value."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'name': 'default'}}
|
||||
env = {'SYSTEM__NAME': 'custom_name'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['name'] == 'custom_name'
|
||||
|
||||
def test_override_int_value(self):
|
||||
"""Test overriding an int value with proper conversion."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'concurrency': {'pipeline': 5}}
|
||||
env = {'CONCURRENCY__PIPELINE': '10'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['concurrency']['pipeline'] == 10
|
||||
assert isinstance(result['concurrency']['pipeline'], int)
|
||||
|
||||
def test_override_int_value_invalid_conversion(self):
|
||||
"""Test that invalid int conversion keeps string value."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'concurrency': {'pipeline': 5}}
|
||||
env = {'CONCURRENCY__PIPELINE': 'not_a_number'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
# Falls back to string when conversion fails
|
||||
assert result['concurrency']['pipeline'] == 'not_a_number'
|
||||
|
||||
def test_override_bool_value_true(self):
|
||||
"""Test overriding bool value with 'true' string."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'enable': False}}
|
||||
env = {'SYSTEM__ENABLE': 'true'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['enable'] is True
|
||||
|
||||
def test_override_bool_value_false(self):
|
||||
"""Test overriding bool value with 'false' string."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'enable': True}}
|
||||
env = {'SYSTEM__ENABLE': 'false'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['enable'] is False
|
||||
|
||||
def test_override_bool_value_various_true_forms(self):
|
||||
"""Test that '1', 'yes', 'on' are treated as true."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'flag': False}}
|
||||
|
||||
for true_val in ['1', 'yes', 'on', 'TRUE']:
|
||||
env = {'SYSTEM__FLAG': true_val}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg.copy())
|
||||
assert result['system']['flag'] is True
|
||||
|
||||
def test_override_float_value(self):
|
||||
"""Test overriding float value with proper conversion."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'timeout': 1.5}}
|
||||
env = {'SYSTEM__TIMEOUT': '2.5'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['timeout'] == 2.5
|
||||
assert isinstance(result['system']['timeout'], float)
|
||||
|
||||
def test_override_list_value(self):
|
||||
"""Test that comma-separated string converts to list."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'disabled_adapters': ['adapter1']}}
|
||||
env = {'SYSTEM__DISABLED_ADAPTERS': 'aiocqhttp,dingtalk,telegram'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['disabled_adapters'] == ['aiocqhttp', 'dingtalk', 'telegram']
|
||||
|
||||
def test_override_list_value_empty_items(self):
|
||||
"""Test that empty items in comma-separated list are filtered."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'disabled_adapters': []}}
|
||||
env = {'SYSTEM__DISABLED_ADAPTERS': 'a,,b,,,c'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
# Empty items should be filtered out
|
||||
assert result['system']['disabled_adapters'] == ['a', 'b', 'c']
|
||||
|
||||
def test_skip_dict_type_override(self):
|
||||
"""Test that dict type values are skipped."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'plugin': {'settings': {'nested': 'value'}}}
|
||||
env = {'PLUGIN__SETTINGS': 'should_not_apply'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
# Dict type should not be overridden
|
||||
assert result['plugin']['settings'] == {'nested': 'value'}
|
||||
|
||||
def test_create_new_key_when_missing(self):
|
||||
"""Test that missing keys are created as strings."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {}}
|
||||
env = {'SYSTEM__NEW_KEY': 'new_value'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['new_key'] == 'new_value'
|
||||
|
||||
def test_create_nested_path(self):
|
||||
"""Test that intermediate dict is created for nested path."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {}
|
||||
env = {'NEW__SECTION__KEY': 'value'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['new']['section']['key'] == 'value'
|
||||
|
||||
def test_skip_non_uppercase_env_vars(self):
|
||||
"""Test that non-uppercase env vars are skipped."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'name': 'default'}}
|
||||
env = {'system__name': 'should_not_apply'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['name'] == 'default'
|
||||
|
||||
def test_skip_env_vars_without_double_underscore(self):
|
||||
"""Test that env vars without __ are skipped."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'name': 'default'}}
|
||||
env = {'SYSTEMNAME': 'should_not_apply'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['name'] == 'default'
|
||||
|
||||
def test_nested_config_path(self):
|
||||
"""Test overriding deeply nested config."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'level1': {'level2': {'level3': 'original'}}}
|
||||
env = {'LEVEL1__LEVEL2__LEVEL3': 'overridden'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['level1']['level2']['level3'] == 'overridden'
|
||||
|
||||
def test_non_dict_current_breaks(self):
|
||||
"""Test that path navigation stops when current is not dict."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': 'not_a_dict'}
|
||||
env = {'SYSTEM__NAME': 'should_not_apply'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
# Should remain unchanged since 'system' is not a dict
|
||||
assert result == {'system': 'not_a_dict'}
|
||||
|
||||
def test_empty_config(self):
|
||||
"""Test that empty config dict is handled."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {}
|
||||
env = {'SOME__KEY': 'value'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['some']['key'] == 'value'
|
||||
|
||||
def test_no_matching_env_vars(self):
|
||||
"""Test that config is unchanged when no matching env vars."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {'system': {'name': 'default'}}
|
||||
env = {'OTHER_VAR': 'value'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result == cfg
|
||||
|
||||
def test_multiple_env_vars_override(self):
|
||||
"""Test multiple env vars applied in order."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {
|
||||
'system': {'name': 'default', 'enable': True},
|
||||
'concurrency': {'pipeline': 5}
|
||||
}
|
||||
env = {
|
||||
'SYSTEM__NAME': 'custom',
|
||||
'SYSTEM__ENABLE': 'false',
|
||||
'CONCURRENCY__PIPELINE': '10'
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['system']['name'] == 'custom'
|
||||
assert result['system']['enable'] is False
|
||||
assert result['concurrency']['pipeline'] == 10
|
||||
@@ -1,524 +1,506 @@
|
||||
"""Unit tests for core TaskContext, TaskWrapper, and AsyncTaskManager.
|
||||
|
||||
Tests cover:
|
||||
- TaskContext initialization, state tracking, serialization
|
||||
- TaskWrapper ID generation, to_dict serialization
|
||||
- AsyncTaskManager task creation, stats, pruning
|
||||
|
||||
Note: Uses import_isolation to break circular import chains.
|
||||
"""
|
||||
Unit tests for AsyncTaskManager and TaskWrapper.
|
||||
|
||||
Tests cover async task lifecycle management:
|
||||
- Task scheduling and tracking
|
||||
- Task completion
|
||||
- Task exception handling
|
||||
- Task cancellation
|
||||
- Multiple task isolation
|
||||
|
||||
Uses module pre-mocking to break circular import chain.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import sys
|
||||
import enum
|
||||
from unittest.mock import MagicMock
|
||||
from importlib import import_module
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
|
||||
# Pre-mock app module BEFORE importing taskmgr to break circular chain:
|
||||
# taskmgr → app → http_controller → groups/knowledge/migration → taskmgr (partial)
|
||||
class FakeMinimalApp:
|
||||
"""Minimal app that only provides event_loop."""
|
||||
class MockLifecycleControlScopeEnum:
|
||||
"""Mock enum value for LifecycleControlScope with .value attribute."""
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def __init__(self, event_loop):
|
||||
self.event_loop = event_loop
|
||||
self.instance_config = MagicMock()
|
||||
self.instance_config.data = {}
|
||||
|
||||
# Pre-register mock app module
|
||||
_mock_app_module = MagicMock()
|
||||
_mock_app_module.Application = FakeMinimalApp
|
||||
sys.modules['langbot.pkg.core.app'] = _mock_app_module
|
||||
|
||||
# Pre-register mock entities module - use proper Enum
|
||||
class LifecycleControlScope(enum.Enum):
|
||||
APPLICATION = 'application'
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
_mock_entities_module = MagicMock()
|
||||
_mock_entities_module.LifecycleControlScope = LifecycleControlScope
|
||||
sys.modules['langbot.pkg.core.entities'] = _mock_entities_module
|
||||
def __repr__(self):
|
||||
return f"LifecycleControlScope.{self.value.upper()}"
|
||||
|
||||
|
||||
def get_taskmgr():
|
||||
"""Import taskmgr after pre-mocking."""
|
||||
return import_module('langbot.pkg.core.taskmgr')
|
||||
class MockLifecycleControlScope:
|
||||
"""Mock enum for LifecycleControlScope."""
|
||||
APPLICATION = MockLifecycleControlScopeEnum('application')
|
||||
PLATFORM = MockLifecycleControlScopeEnum('platform')
|
||||
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
|
||||
PLUGIN = MockLifecycleControlScopeEnum('plugin')
|
||||
|
||||
|
||||
def get_entities():
|
||||
"""Get pre-registered mock entities module."""
|
||||
return sys.modules['langbot.pkg.core.entities']
|
||||
@contextmanager
|
||||
def isolated_taskmgr_import() -> Generator[None, None, None]:
|
||||
"""Context manager to isolate circular imports for taskmgr testing."""
|
||||
# Mock modules that cause circular imports
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
mock_importutil = MagicMock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
||||
|
||||
mock_http_controller = MagicMock()
|
||||
|
||||
mock_rag_mgr = MagicMock()
|
||||
|
||||
mocks = {
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
'langbot.pkg.api.http.controller.main': mock_http_controller,
|
||||
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
}
|
||||
|
||||
# Save original state
|
||||
saved = {}
|
||||
for name in mocks:
|
||||
if name in sys.modules:
|
||||
saved[name] = sys.modules[name]
|
||||
|
||||
# Clear taskmgr to force re-import
|
||||
taskmgr_name = 'langbot.pkg.core.taskmgr'
|
||||
if taskmgr_name in sys.modules:
|
||||
saved[taskmgr_name] = sys.modules[taskmgr_name]
|
||||
|
||||
try:
|
||||
# Apply mocks
|
||||
for name, module in mocks.items():
|
||||
sys.modules[name] = module
|
||||
|
||||
# Clear taskmgr
|
||||
sys.modules.pop(taskmgr_name, None)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore
|
||||
for name in mocks:
|
||||
if name in saved:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
if taskmgr_name in saved:
|
||||
sys.modules[taskmgr_name] = saved[taskmgr_name]
|
||||
else:
|
||||
sys.modules.pop(taskmgr_name, None)
|
||||
|
||||
|
||||
class TestTaskContextReal:
|
||||
"""Tests for real TaskContext class (no circular import)."""
|
||||
def get_taskmgr_classes():
|
||||
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
|
||||
with isolated_taskmgr_import():
|
||||
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
|
||||
return TaskContext, TaskWrapper, AsyncTaskManager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_new(self):
|
||||
"""TaskContext.new() creates instance."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
def create_mock_app():
|
||||
"""Create a mock Application for testing."""
|
||||
mock_app = Mock()
|
||||
mock_app.event_loop = asyncio.get_running_loop()
|
||||
mock_app.instance_config = Mock()
|
||||
mock_app.instance_config.data = {
|
||||
'system': {
|
||||
'task_retention': {
|
||||
'completed_limit': 200,
|
||||
}
|
||||
}
|
||||
}
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestTaskContext:
|
||||
"""Tests for TaskContext class."""
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""Test that TaskContext initializes with default values."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
assert ctx.current_action == 'default'
|
||||
assert ctx.log == ''
|
||||
assert ctx.metadata == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_trace(self):
|
||||
"""TaskContext.trace adds formatted log."""
|
||||
taskmgr = get_taskmgr()
|
||||
def test_set_current_action(self):
|
||||
"""Test setting current action."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.trace('test message', action='test_action')
|
||||
ctx.set_current_action('installing_plugin')
|
||||
assert ctx.current_action == 'installing_plugin'
|
||||
|
||||
assert ctx.current_action == 'test_action'
|
||||
assert 'test message' in ctx.log
|
||||
assert 'test_action' in ctx.log
|
||||
# Contains timestamp format
|
||||
assert '|' in ctx.log
|
||||
def test_trace_without_action(self):
|
||||
"""Test trace method without action override."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_multiple_traces(self):
|
||||
"""TaskContext accumulates multiple traces."""
|
||||
taskmgr = get_taskmgr()
|
||||
ctx.trace('Starting process')
|
||||
assert 'Starting process' in ctx.log
|
||||
assert ctx.current_action == 'default'
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.trace('first')
|
||||
ctx.trace('second')
|
||||
def test_trace_with_action_override(self):
|
||||
"""Test trace method with action override."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
assert 'first' in ctx.log
|
||||
assert 'second' in ctx.log
|
||||
ctx.trace('Downloading', action='download')
|
||||
assert 'Downloading' in ctx.log
|
||||
assert ctx.current_action == 'download'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_to_dict(self):
|
||||
"""TaskContext.to_dict returns all fields."""
|
||||
taskmgr = get_taskmgr()
|
||||
def test_trace_accumulates_logs(self):
|
||||
"""Test that trace accumulates log entries."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.trace('log entry')
|
||||
ctx.trace('Step 1')
|
||||
ctx.trace('Step 2')
|
||||
ctx.trace('Step 3')
|
||||
|
||||
assert 'Step 1' in ctx.log
|
||||
assert 'Step 2' in ctx.log
|
||||
assert 'Step 3' in ctx.log
|
||||
# Each trace adds a newline
|
||||
assert ctx.log.count('\n') == 3
|
||||
|
||||
def test_to_dict_serialization(self):
|
||||
"""Test to_dict serialization."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
ctx.set_current_action('test_action')
|
||||
ctx.trace('Test message')
|
||||
ctx.metadata['key'] = 'value'
|
||||
|
||||
result = ctx.to_dict()
|
||||
|
||||
assert 'current_action' in result
|
||||
assert 'log' in result
|
||||
assert 'metadata' in result
|
||||
assert result['log'] == ctx.log
|
||||
assert result['current_action'] == 'test_action'
|
||||
assert 'Test message' in result['log']
|
||||
assert result['metadata'] == {'key': 'value'}
|
||||
|
||||
def test_static_new_factory(self):
|
||||
"""Test TaskContext.new() factory method."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext.new()
|
||||
|
||||
assert isinstance(ctx, TaskContext)
|
||||
assert ctx.current_action == 'default'
|
||||
|
||||
def test_static_placeholder_singleton(self):
|
||||
"""Test TaskContext.placeholder() returns singleton."""
|
||||
with isolated_taskmgr_import():
|
||||
from langbot.pkg.core.taskmgr import TaskContext
|
||||
|
||||
# Reset global placeholder
|
||||
import langbot.pkg.core.taskmgr as taskmgr_module
|
||||
taskmgr_module.placeholder_context = None
|
||||
|
||||
ctx1 = TaskContext.placeholder()
|
||||
ctx2 = TaskContext.placeholder()
|
||||
|
||||
assert ctx1 is ctx2
|
||||
|
||||
def test_metadata_is_mutable_dict(self):
|
||||
"""Test that metadata is a mutable dict."""
|
||||
TaskContext, _, _ = get_taskmgr_classes()
|
||||
ctx = TaskContext()
|
||||
|
||||
ctx.metadata['count'] = 5
|
||||
ctx.metadata['items'] = ['a', 'b', 'c']
|
||||
|
||||
assert ctx.metadata['count'] == 5
|
||||
assert len(ctx.metadata['items']) == 3
|
||||
|
||||
|
||||
class TestTaskWrapper:
|
||||
"""Tests for TaskWrapper class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_set_current_action(self):
|
||||
"""set_current_action updates action."""
|
||||
taskmgr = get_taskmgr()
|
||||
async def test_id_auto_increment(self):
|
||||
"""Test that task IDs auto-increment."""
|
||||
TaskContext, TaskWrapper, _ = get_taskmgr_classes()
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.set_current_action('new_action')
|
||||
# Reset ID index
|
||||
TaskWrapper._id_index = 0
|
||||
|
||||
assert ctx.current_action == 'new_action'
|
||||
mock_app = create_mock_app()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_context_metadata(self):
|
||||
"""TaskContext metadata can be set."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.metadata['key'] = 'value'
|
||||
|
||||
assert ctx.metadata['key'] == 'value'
|
||||
assert ctx.to_dict()['metadata']['key'] == 'value'
|
||||
|
||||
def test_task_context_placeholder_singleton(self):
|
||||
"""placeholder returns same instance."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
ctx1 = taskmgr.TaskContext.placeholder()
|
||||
ctx2 = taskmgr.TaskContext.placeholder()
|
||||
|
||||
assert ctx1 is ctx2
|
||||
|
||||
|
||||
class TestTaskWrapperReal:
|
||||
"""Tests for real TaskWrapper class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_creates_task(self):
|
||||
"""TaskWrapper creates and wraps asyncio.Task."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
async def simple_coro():
|
||||
return 42
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, simple_coro(), name='test')
|
||||
|
||||
assert wrapper.name == 'test'
|
||||
assert wrapper.task is not None
|
||||
assert isinstance(wrapper.task, asyncio.Task)
|
||||
|
||||
result = await wrapper.task
|
||||
assert result == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_with_custom_context(self):
|
||||
"""TaskWrapper uses provided TaskContext."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
ctx.set_current_action('custom')
|
||||
|
||||
async def coro():
|
||||
async def dummy_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
return 'done'
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, coro(), context=ctx)
|
||||
wrapper1 = TaskWrapper(mock_app, dummy_coro())
|
||||
wrapper2 = TaskWrapper(mock_app, dummy_coro())
|
||||
|
||||
assert wrapper.task_context.current_action == 'custom'
|
||||
assert wrapper1.id == 0
|
||||
assert wrapper2.id == 1
|
||||
|
||||
await wrapper.task
|
||||
# Clean up
|
||||
wrapper1.cancel()
|
||||
wrapper2.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_exception_capture(self):
|
||||
"""TaskWrapper captures exception from failed task."""
|
||||
taskmgr = get_taskmgr()
|
||||
async def test_default_task_type_and_kind(self):
|
||||
"""Test default task_type and kind values."""
|
||||
_, TaskWrapper, _ = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
async def dummy_coro():
|
||||
return 'done'
|
||||
|
||||
wrapper = TaskWrapper(mock_app, dummy_coro())
|
||||
|
||||
assert wrapper.task_type == 'system'
|
||||
assert wrapper.kind == 'system_task'
|
||||
|
||||
wrapper.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_dict_serialization(self):
|
||||
"""Test TaskWrapper.to_dict serialization."""
|
||||
_, TaskWrapper, _ = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
async def immediate_coro():
|
||||
return 'result'
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
mock_app, immediate_coro(),
|
||||
name='test_task',
|
||||
label='Test Task',
|
||||
)
|
||||
|
||||
# Wait for task to complete
|
||||
await wrapper.task
|
||||
|
||||
result = wrapper.to_dict()
|
||||
|
||||
assert result['name'] == 'test_task'
|
||||
assert result['label'] == 'Test Task'
|
||||
assert result['task_type'] == 'system'
|
||||
assert result['runtime']['done'] == True
|
||||
assert result['runtime']['result'] == 'result'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_dict_with_exception(self):
|
||||
"""Test TaskWrapper.to_dict when task has exception."""
|
||||
_, TaskWrapper, _ = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
async def failing_coro():
|
||||
raise ValueError('test error')
|
||||
raise ValueError('Test error')
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, failing_coro())
|
||||
wrapper = TaskWrapper(mock_app, failing_coro())
|
||||
|
||||
# Let task complete with exception
|
||||
await asyncio.sleep(0.01)
|
||||
# Wait for task to complete
|
||||
try:
|
||||
await wrapper.task
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
exception = wrapper.assume_exception()
|
||||
assert exception is not None
|
||||
assert isinstance(exception, ValueError)
|
||||
assert 'test error' in str(exception)
|
||||
result = wrapper.to_dict()
|
||||
|
||||
assert result['runtime']['done'] == True
|
||||
assert result['runtime']['exception'] == 'Test error'
|
||||
assert 'exception_traceback' in result['runtime']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_result_capture(self):
|
||||
"""TaskWrapper captures result from completed task."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
async def coro():
|
||||
return 'result_value'
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, coro())
|
||||
|
||||
await wrapper.task
|
||||
|
||||
result = wrapper.assume_result()
|
||||
assert result == 'result_value'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_cancel(self):
|
||||
"""TaskWrapper.cancel cancels the task."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
async def test_cancel_task(self):
|
||||
"""Test cancel method cancels the asyncio task."""
|
||||
_, TaskWrapper, _ = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
async def long_coro():
|
||||
await asyncio.sleep(10)
|
||||
return 'done'
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, long_coro())
|
||||
wrapper = TaskWrapper(mock_app, long_coro())
|
||||
|
||||
# Task should be running
|
||||
assert not wrapper.task.done()
|
||||
|
||||
wrapper.cancel()
|
||||
|
||||
# Give it a moment to be cancelled
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert wrapper.task.cancelled() or wrapper.task.done()
|
||||
assert wrapper.task.done()
|
||||
assert wrapper.task.cancelled()
|
||||
|
||||
|
||||
class TestAsyncTaskManager:
|
||||
"""Tests for AsyncTaskManager class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_to_dict(self):
|
||||
"""TaskWrapper.to_dict serializes task info."""
|
||||
taskmgr = get_taskmgr()
|
||||
async def test_create_task_adds_to_list(self):
|
||||
"""Test that create_task adds task to tasks list."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
async def coro():
|
||||
return 42
|
||||
async def dummy_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
return 'done'
|
||||
|
||||
wrapper = taskmgr.TaskWrapper(app, coro(), name='dict_test', label='Test')
|
||||
|
||||
await wrapper.task
|
||||
|
||||
result = wrapper.to_dict()
|
||||
|
||||
assert result['name'] == 'dict_test'
|
||||
assert result['label'] == 'Test'
|
||||
assert 'runtime' in result
|
||||
assert result['runtime']['done'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_wrapper_id_increment(self):
|
||||
"""TaskWrapper IDs increment."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
async def coro():
|
||||
return 1
|
||||
|
||||
wrapper1 = taskmgr.TaskWrapper(app, coro())
|
||||
wrapper2 = taskmgr.TaskWrapper(app, coro())
|
||||
|
||||
assert wrapper2.id > wrapper1.id
|
||||
|
||||
|
||||
class TestAsyncTaskManagerReal:
|
||||
"""Tests for real AsyncTaskManager class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_create_task(self):
|
||||
"""AsyncTaskManager creates and tracks tasks."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def coro():
|
||||
return 'result'
|
||||
|
||||
wrapper = manager.create_task(coro(), name='test')
|
||||
wrapper = manager.create_task(dummy_coro())
|
||||
|
||||
assert wrapper in manager.tasks
|
||||
assert wrapper.name == 'test'
|
||||
assert len(manager.tasks) == 1
|
||||
|
||||
await wrapper.task
|
||||
wrapper.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_create_user_task(self):
|
||||
"""create_user_task creates user-type task."""
|
||||
taskmgr = get_taskmgr()
|
||||
async def test_get_stats_counts_correctly(self):
|
||||
"""Test get_stats returns correct counts."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
async def immediate_coro():
|
||||
return 'done'
|
||||
|
||||
async def coro():
|
||||
return 'user_result'
|
||||
async def delayed_coro():
|
||||
await asyncio.sleep(0.1)
|
||||
return 'done'
|
||||
|
||||
wrapper = manager.create_user_task(coro())
|
||||
# Create tasks
|
||||
w1 = manager.create_task(immediate_coro())
|
||||
w2 = manager.create_task(delayed_coro())
|
||||
|
||||
# Wait for first to complete
|
||||
await w1.task
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert stats['total'] == 2
|
||||
assert stats['completed'] == 1
|
||||
assert stats['running'] == 1
|
||||
|
||||
w2.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tasks_dict_filters_by_type(self):
|
||||
"""Test get_tasks_dict filters by type."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
async def dummy_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Create system and user tasks
|
||||
w1 = manager.create_task(dummy_coro(), task_type='system')
|
||||
w2 = manager.create_task(dummy_coro(), task_type='user')
|
||||
w3 = manager.create_task(dummy_coro(), task_type='user')
|
||||
|
||||
result = manager.get_tasks_dict(type='user')
|
||||
|
||||
assert len(result['tasks']) == 2
|
||||
for t in result['tasks']:
|
||||
assert t['task_type'] == 'user'
|
||||
|
||||
w1.cancel()
|
||||
w2.cancel()
|
||||
w3.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_scope(self):
|
||||
"""Test cancel_by_scope cancels matching tasks."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
|
||||
mock_app = create_mock_app()
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
async def long_coro():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Create task with APPLICATION scope
|
||||
w1 = manager.create_task(
|
||||
long_coro(),
|
||||
scopes=[MockLifecycleControlScope.APPLICATION]
|
||||
)
|
||||
|
||||
# Create task with different scope
|
||||
w2 = manager.create_task(
|
||||
long_coro(),
|
||||
scopes=[MockLifecycleControlScope.PIPELINE]
|
||||
)
|
||||
|
||||
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert w1.task.cancelled() or w1.task.done()
|
||||
assert not w2.task.done()
|
||||
|
||||
w2.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_task_by_id(self):
|
||||
"""Test cancel_task cancels specific task by ID."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
async def long_coro():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
w1 = manager.create_task(long_coro())
|
||||
w2 = manager.create_task(long_coro())
|
||||
|
||||
manager.cancel_task(w1.id)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert w1.task.done()
|
||||
assert not w2.task.done()
|
||||
|
||||
w2.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_task_sets_user_type(self):
|
||||
"""Test create_user_task sets task_type to 'user'."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
async def dummy_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
wrapper = manager.create_user_task(dummy_coro())
|
||||
|
||||
assert wrapper.task_type == 'user'
|
||||
|
||||
await wrapper.task
|
||||
wrapper.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_multiple_tasks_isolated(self):
|
||||
"""Multiple tasks run independently."""
|
||||
taskmgr = get_taskmgr()
|
||||
async def test_get_task_by_id(self):
|
||||
"""Test get_task_by_id returns correct task."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
results = []
|
||||
|
||||
async def task_a():
|
||||
results.append('a')
|
||||
|
||||
async def task_b():
|
||||
results.append('b')
|
||||
|
||||
w1 = manager.create_task(task_a(), name='a')
|
||||
w2 = manager.create_task(task_b(), name='b')
|
||||
|
||||
await asyncio.gather(w1.task, w2.task)
|
||||
|
||||
assert 'a' in results
|
||||
assert 'b' in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_get_task_by_id(self):
|
||||
"""get_task_by_id finds task."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def coro():
|
||||
return 1
|
||||
|
||||
wrapper = manager.create_task(coro())
|
||||
|
||||
found = manager.get_task_by_id(wrapper.id)
|
||||
assert found is wrapper
|
||||
|
||||
not_found = manager.get_task_by_id(99999)
|
||||
assert not_found is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_cancel_task(self):
|
||||
"""cancel_task cancels specific task."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def long():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
wrapper = manager.create_task(long())
|
||||
|
||||
manager.cancel_task(wrapper.id)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
assert wrapper.task.cancelled() or wrapper.task.done()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_cancel_by_scope(self):
|
||||
"""cancel_by_scope cancels matching scope tasks."""
|
||||
taskmgr = get_taskmgr()
|
||||
entities = get_entities()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def long():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
async def app_long():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Create task with PLATFORM scope
|
||||
platform_wrapper = manager.create_task(
|
||||
long(),
|
||||
scopes=[entities.LifecycleControlScope.PLATFORM],
|
||||
)
|
||||
|
||||
# Create task with APPLICATION scope
|
||||
manager.create_task(
|
||||
app_long(),
|
||||
scopes=[entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
manager.cancel_by_scope(entities.LifecycleControlScope.PLATFORM)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Platform task cancelled
|
||||
assert platform_wrapper.task.cancelled() or platform_wrapper.task.done()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_get_stats(self):
|
||||
"""get_stats returns task counts."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def quick():
|
||||
return 1
|
||||
|
||||
for _ in range(3):
|
||||
w = manager.create_task(quick())
|
||||
await w.task
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert stats['total'] >= 3
|
||||
assert stats['completed'] >= 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_get_tasks_dict(self):
|
||||
"""get_tasks_dict filters by type."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def coro():
|
||||
return 1
|
||||
|
||||
system_w = manager.create_task(coro(), task_type='system')
|
||||
user_w = manager.create_user_task(coro())
|
||||
|
||||
await asyncio.gather(system_w.task, user_w.task)
|
||||
|
||||
system_tasks = manager.get_tasks_dict(type='system')
|
||||
assert all(t['task_type'] == 'system' for t in system_tasks['tasks'])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_wait_all(self):
|
||||
"""wait_all waits for all tasks."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def delayed():
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
for _ in range(3):
|
||||
manager.create_task(delayed())
|
||||
|
||||
await manager.wait_all()
|
||||
|
||||
stats = manager.get_stats()
|
||||
assert stats['running'] == 0
|
||||
|
||||
|
||||
class TestTaskPruningReal:
|
||||
"""Tests for real task pruning behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prune_completed_tasks(self):
|
||||
"""Completed tasks are pruned when exceeding limit."""
|
||||
taskmgr = get_taskmgr()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
app = FakeMinimalApp(loop)
|
||||
app.instance_config.data = {'system': {'task_retention': {'completed_limit': 3}}}
|
||||
|
||||
manager = taskmgr.AsyncTaskManager(app)
|
||||
|
||||
async def quick():
|
||||
return 1
|
||||
|
||||
# Create more than limit
|
||||
for _ in range(5):
|
||||
w = manager.create_task(quick())
|
||||
await w.task
|
||||
async def dummy_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Completed count should be <= limit
|
||||
completed = sum(1 for w in manager.tasks if w.task.done())
|
||||
assert completed <= 3
|
||||
w1 = manager.create_task(dummy_coro())
|
||||
w2 = manager.create_task(dummy_coro())
|
||||
|
||||
found = manager.get_task_by_id(w1.id)
|
||||
assert found is w1
|
||||
|
||||
not_found = manager.get_task_by_id(9999)
|
||||
assert not_found is None
|
||||
|
||||
w1.cancel()
|
||||
w2.cancel()
|
||||
|
||||
Reference in New Issue
Block a user