mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 07:46:02 +00:00
Feat/test build (#2174)
* fix(ci): update unit-test workflow paths to match current source layout Replace stale pkg/** filter with src/langbot/** and add uv.lock. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * docs(tests): update README to reflect current test layout - Fix stale paths: tests/pipeline → tests/unit_tests/pipeline - Update CI Python versions: 3.11, 3.12, 3.13 - Add test directory structure for box, config, platform, plugin, provider, storage - Document pytest markers and uv commands - Mention planned E2E tests Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add shared test factories package Create tests/factories/ with reusable test factories: - FakeApp: mock application with all dependencies - Message chains: text_chain, mention_chain, image_chain - Query factories: text_query, group_text_query, command_query, etc. No test changes - maintains backward compatibility. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add fake provider factory Add tests/factories/provider.py with: - FakeProvider: deterministic fake LLM provider - Error simulation: timeout, auth, rate-limit, malformed - Request capture for assertions - fake_model: mock model with attached provider Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add fake platform factory Add tests/factories/platform.py with: - FakePlatform: simulated platform adapter - Inbound message construction: friend/group/image - Mention-bot flag simulation - Outbound message capture for assertions - Streaming output support simulation - Send failure simulation Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add comprehensive message/query factories Extend tests/factories/message.py with: - file_query: file attachment query - unsupported_query: unknown message segment - voice_query: audio/voice query - at_all_query: group @All mention - query_with_session: query with session object - query_with_config: query with custom pipeline config Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add fake message flow smoke test Create tests/smoke/test_fake_message_flow.py: - TestFakeMessageFlow: factory verification tests - TestMessageFlowIntegration: minimal flow smoke test - Tests FakeApp, FakeProvider, FakePlatform, query factories - Verifies LANGBOT_FAKE_PONG marker response - Captures outbound messages for assertions Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add developer test-quick command Add scripts/test-quick.sh and Makefile with: - test-quick: runs ruff check + unit tests + smoke tests - No real provider keys or platform accounts required - Suitable for local branch self-test Update tests/README.md: - Document test-quick command - Document test factories package - Add smoke tests and factories directory structure Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(test): make test-quick reliable as developer gate Fixes for D-001验收问题: 1. test-quick.sh: use set -euo pipefail, uv run ruff, no tail pipe 2. Remove unused imports in factories (app.py, platform.py, provider.py) 3. Fix unused variable in smoke test 4. Add noqa: E402 to test_n8nsvapi.py lazy imports 5. Update smoke test docs: "minimal fake flow" not full pipeline Now test-quick is a reliable gate: lint failures exit 1, test failures propagate. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(unit): add preproc and taskmgr unit tests U-001: Pipeline Preprocessor tests - Normal text message processing - Empty message handling - Image segment with/without vision model - Model selection and fallback - Variable extraction U-004: Core Task Manager tests (pattern-based) - Task creation and tracking patterns - Task cancellation patterns - Scope-based cancellation - Task type filtering - Pruning completed tasks - Wait all tasks Taskmgr tests use pattern-based approach to avoid circular import in source code (taskmgr → app → http_controller → migration → taskmgr). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(unit): add config loader unit tests U-005: Config Loader tests - Valid YAML config loading - Valid JSON config loading - Invalid YAML/JSON error behavior - Missing config file creation from template - Template completion for missing keys - ConfigManager load/dump operations - Exists check for both YAML and JSON All tests use tmp_path fixture, no real project config. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(unit): add chat and command handler pattern tests U-002: Chat Handler tests (pattern-based) - Normal message event emission pattern - prevent_default handling - User message alteration pattern - Runner selection pattern - Streaming/non-streaming response patterns - Exception handling modes (show-error, show-hint, hide) - Message history update pattern - Telemetry payload pattern U-003: Command Handler tests (pattern-based) - Command parsing and text extraction - Event creation pattern - Privilege/admin check pattern - Command result handling (text, error, image) - prevent_default handling - String truncation helper Uses pattern-based testing to avoid circular import issues in source code. Direct imports of handler modules trigger circular import chain. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * style: fix unused imports after ruff auto-fix Remove unused imports in test files: - test_config_loader.py: remove unused os - test_taskmgr.py: remove unused Mock - test_preproc.py: remove unused unsupported_query, image_chain Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(unit): improve taskmgr tests to test real classes U-004 improved: Tests now import and test actual classes: - TaskContext: new(), trace(), to_dict(), placeholder() - TaskWrapper: task creation, context, exception/result capture, cancel, to_dict - AsyncTaskManager: create_task, create_user_task, cancel_task, cancel_by_scope - Task pruning behavior Uses pre-mocking technique: - Mock langbot.pkg.core.app before import (breaks circular chain) - Mock langbot.pkg.core.entities with proper Enum All 24 tests now test real class behavior, not patterns. taskmgr.py coverage should improve significantly. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * refactor(test): consolidate FakeApp and add sys.modules isolation utility - Extract tests/utils/import_isolation.py with isolated_sys_modules context manager - Extend tests/factories/app.py FakeApp with handler-specific attributes - Refactor test_chat_handler.py to use centralized FakeApp and cached imports - Refactor test_command_handler.py with mock_execute_factory fixture - Refactor test_smoke.py to move import-time sys.modules manipulation into fixture - Add SQLite migration integration tests (G-002) - Add HTTP API smoke integration tests (G-005) - Update CI workflow to call pytest for SQLite migrations (G-004) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add developer quality gate consolidation (G-007) - Add scripts/test-integration-fast.sh for fast integration tests - Add scripts/test-coverage.sh with 12% baseline threshold - Update Makefile with test-integration-fast, test-coverage, test-all-local - Update CI workflow with integration and coverage jobs - Add smoke marker to pytest.ini - Update tests/README.md with quality gate layers documentation - Add tests/integration/pipeline/ for pipeline stage-chain tests Quality gate layers: - Quick: ruff + unit + smoke (~2 min) - Fast Integration: SQLite/API/Pipeline (~3 min) - Coverage: 12% threshold gate (~8 min) - Full Local: all three combined Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): add PostgreSQL migration slow integration tests (G-003) - Add tests/integration/persistence/test_migrations_postgres.py - All tests marked with @pytest.mark.slow - Tests skip when TEST_POSTGRES_URL is not set (no local PostgreSQL) - Database isolation via clean_tables and clean_alembic_version fixtures - Update CI workflow to use pytest instead of inline Python script - Remove TODO(G-003) comment - Update tests/README.md with PostgreSQL test documentation Covered scenarios: - Baseline stamp sets revision - Upgrade from baseline to head - Upgrade idempotent - Get current on unstamped DB returns None Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(test): Phase 1.5 coverage expansion - COV-001 to COV-013 Coverage baseline raised from 13.65% to 26% (+12.35%) Gate raised from 12% to 18% Tasks completed: - COV-001: Command system unit tests (100% coverage) - COV-002: API service unit tests batch 1 (user/apikey/model/provider) - COV-003: Provider model manager unit tests - COV-004: Pipeline remaining stage tests (aggregator/cntfilter/longtext/msgtrun) - COV-005: Storage and utils coverage pass - COV-006: Gate ratchet 12%→15% - COV-007: Gate ratchet 15%→18% - COV-008: API service batch 2 (bot/pipeline/webhook/space/maintenance/mcp) - COV-009: Blocked - API controller circular import issue documented - COV-010: Plugin runtime unit tests (+0.08%) - COV-011: RAG and vector unit tests (+0.68%) - COV-012: Core boot and migration unit tests - COV-013: Provider requester logic unit tests (+0.62%) Key additions: - tests/utils/import_isolation.py: sys.modules isolation for circular imports - Provider requester mock tests: proved HTTP-dependent code can be tested locally - Vector filter utilities: 100% coverage on pure functions - API services: fake persistence pattern for unit testing Blocked issue COV-009 documented in langbot-test-plan/1.5/issues/ Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(phase1): add unit tests for telemetry, plugin, rag, persistence Add initial unit tests for Phase 1 of test coverage improvement: - telemetry: test initialization, payload sanitization, early returns (14.3% → 62.9%) - plugin: test _parse_plugin_id static method - rag: test _to_i18n_name static method - persistence: test serialize_model with datetime handling Overall core coverage: 41.9% → 42.2% Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(phase2): add unit tests for core, persistence, plugin, utils - Add test_handler_helpers.py for plugin handler helpers (7 tests) - Add test_mgr_methods.py for persistence manager (5 tests) - Add test_app_config_validation.py for core app config (12 tests) - Add test_knowledge_service.py for API knowledge service (22 tests) - Add test_kbmgr.py for RAG knowledge base manager (39 tests) - Add test_survey_manager.py for survey manager (22 tests) - Add test_connector_methods.py for plugin connector (24 tests) - Add test_funcschema.py for utils function schema (9 tests) - Add test_platform.py for utils platform detection (7 tests) - Add test_extract_deps.py for plugin deps extraction (7 tests) - Add test_database_decorator.py for persistence decorator (7 tests) - Add test_load_config.py for core config loading (19 tests) - Add COVERAGE_EXCLUSIONS.md documenting external adapter exclusions - Fix test_chat_session_limit.py path for portability Coverage: core 28% → 30%, persistence 24% → 24.4%, plugin 27% → 28% Total: 1082 tests passed, core module coverage 45.5% Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(integration): add API controller integration tests - Add test_pipelines.py (10 tests) covering pipelines CRUD operations - GET/POST/PUT/DELETE on /api/v1/pipelines - Extensions endpoint - Metadata endpoint - Coverage: pipelines controller 27% → 80% - Add test_providers.py (10 tests) covering provider/model management - Provider CRUD with model counts - LLM model CRUD - Coverage: providers controller 23% → 81%, models 29% → 45% Tests use Quart TestClient with mocked services for real HTTP behavior without external dependencies. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(integration): add knowledge, bots, and model endpoints tests - Add test_knowledge.py (10 tests) covering knowledge base management - CRUD operations on /api/v1/knowledge/bases - Files management endpoints - Retrieve endpoint with validation - Coverage: knowledge/base.py 26% → 91% - Add test_bots.py (9 tests) covering bot management - CRUD operations on /api/v1/platform/bots - Logs endpoint - Send message endpoint with validation - Coverage: platform/bots.py 24% → 87% - Extend test_providers.py (+4 tests) for embedding/rerank models - Embedding models CRUD - Rerank models CRUD - Coverage: provider/models.py 29% → 60% Total integration tests: 53 (smoke 12 + pipelines 10 + providers 14 + knowledge 10 + bots 9) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(integration): add embed and monitoring endpoint tests Add integration tests for embed widget and monitoring API endpoints: - test_embed.py: 15 tests for widget.js, logo, turnstile, messages, reset, feedback - test_monitoring.py: 15 tests for overview, messages, llm-calls, sessions, errors, export Coverage improvements: - embed.py: 17% → 56% - monitoring.py: 17% → 93% Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(e2e): add minimal startup E2E tests Add E2E tests for LangBot startup flow: - tests/e2e/utils/config_factory.py: minimal config generation - tests/e2e/utils/process_manager.py: LangBot subprocess management - tests/e2e/conftest.py: E2E fixtures (session-scoped process) - tests/e2e/test_startup.py: 12 tests for startup verification Tests verify: - boot.py + stages execution - database initialization (SQLite) - API availability - migrations applied Uses embedded databases (SQLite, Chroma) - no external dependencies. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test(quality): fix fake tests and add missing coverage P0 fixes: - telemetry: rewrite fake tests with real behavior verification (25 tests) - config: delete copied-source tests, use proper imports (2 deleted) - persistence: fix try-except pass to verify specific errors P1 fixes: - pipeline: add real FixedWindowAlgo tests instead of mocks (12 tests) - provider: add SessionManager and ToolManager tests (25 tests) - storage: add S3StorageProvider tests with moto mock (16 tests) - plugin: add handler action tests for setting inheritance (15 tests) - rag: add file storage and ZIP processing tests (21 tests) - vector: add VDB filter conversion tests (30 tests) P2 fixes: - pipeline/msgtrun: strengthen assertions for exact message count - api: add response structure validation in integration tests New test files: - provider/test_session_manager.py - provider/test_tool_manager.py - storage/test_s3storage.py - plugin/test_handler_actions.py - rag/test_file_storage.py - vector/test_vdb_filter_conversion.py Source code bugs documented: - provider: TokenManager.next_token() ZeroDivisionError - telemetry: send_tasks class variable shared state - command: empty command IndexError, unused parameters - utils: funcschema KeyError - entity: vector.py independent declarative_base Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * docs(test): update coverage stats and test structure - Update coverage from 22% to 30% - Add new test files to structure: - provider: session_manager, tool_manager - storage: s3storage - plugin: handler_actions - rag: file_storage - vector: vdb_filter_conversion - telemetry: rewritten tests - Update module coverage percentages Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test: add 105 new unit tests for untested core functionality Add comprehensive tests for B-class issues (core functionality untested): Pipeline: - test_pool.py: QueryPool ID generation, caching, async context (12 tests) - test_ratelimit.py: Fixed timing-sensitive test tolerance - test_pipelinemgr.py: Use real Pydantic StageProcessResult instead of Mock Utils: - test_version.py: Version comparison functions (20 tests) - test_logcache.py: Log page management and retrieval (18 tests) - test_httpclient.py: HTTP session pool management (10 tests) - test_proxy.py: Proxy configuration from env and config (10 tests) - test_image.py: URL parsing and base64 extraction (12 tests) - test_pkgmgr.py: Pip command generation (8 tests) Discover: - test_engine.py: I18nString, Metadata, Component manifest (15 tests) Test count: 1193 → 1298 (+105 tests) Note: Some B-class issues cannot be tested due to circular import bugs filed as GitHub issues #2175 (pipeline) and #2176 (persistence). * test: tighten phase 1 coverage contracts * test: align ci integration isolation --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
1
tests/unit_tests/api/__init__.py
Normal file
1
tests/unit_tests/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit tests for LangBot API HTTP service layer."""
|
||||
16
tests/unit_tests/api/service/__init__.py
Normal file
16
tests/unit_tests/api/service/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Unit tests for API HTTP service layer.
|
||||
|
||||
Tests real service business logic with mocked dependencies:
|
||||
- persistence_mgr (database operations)
|
||||
- model_mgr (runtime model management)
|
||||
- platform_mgr (platform management)
|
||||
- plugin_connector (plugin runtime)
|
||||
- adjacent services (cross-service calls)
|
||||
|
||||
Does NOT:
|
||||
- Start real Quart server
|
||||
- Access real database
|
||||
- Call real provider/platform/network
|
||||
|
||||
Uses tests.factories.FakeApp as base mock application.
|
||||
"""
|
||||
429
tests/unit_tests/api/service/test_apikey_service.py
Normal file
429
tests/unit_tests/api/service/test_apikey_service.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Unit tests for ApiKeyService.
|
||||
|
||||
Tests API key CRUD operations with mocked persistence layer.
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/apikey.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.apikey import ApiKeyService
|
||||
from langbot.pkg.entity.persistence.apikey import ApiKey
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestApiKeyServiceGetApiKeys:
|
||||
"""Tests for get_api_keys method."""
|
||||
|
||||
async def test_get_api_keys_empty_list(self):
|
||||
"""Returns empty list when no API keys exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=[])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
if entity
|
||||
else {}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_keys()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_get_api_keys_returns_serialized_list(self):
|
||||
"""Returns serialized list of API keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create mock API key entities
|
||||
key1 = Mock(spec=ApiKey)
|
||||
key1.id = 1
|
||||
key1.name = 'Test Key 1'
|
||||
key1.key = 'lbk_test_key_1'
|
||||
key1.description = 'First test key'
|
||||
|
||||
key2 = Mock(spec=ApiKey)
|
||||
key2.id = 2
|
||||
key2.name = 'Test Key 2'
|
||||
key2.key = 'lbk_test_key_2'
|
||||
key2.description = 'Second test key'
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=[key1, key2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_keys()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Test Key 1'
|
||||
assert result[1]['name'] == 'Test Key 2'
|
||||
|
||||
|
||||
class TestApiKeyServiceCreateApiKey:
|
||||
"""Tests for create_api_key method."""
|
||||
|
||||
async def test_create_api_key_generates_key_with_prefix(self):
|
||||
"""Creates API key with 'lbk_' prefix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_key = Mock(spec=ApiKey)
|
||||
created_key.id = 1
|
||||
created_key.name = 'New Key'
|
||||
created_key.key = 'lbk_fixed-token'
|
||||
created_key.description = 'Test description'
|
||||
select_result = Mock()
|
||||
select_result.first = Mock(return_value=created_key)
|
||||
insert_params = []
|
||||
|
||||
async def mock_execute(query):
|
||||
params = query.compile().params
|
||||
if {'name', 'key', 'description'}.issubset(params):
|
||||
insert_params.append(params)
|
||||
return Mock()
|
||||
return select_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': 1,
|
||||
'name': entity.name,
|
||||
'key': entity.key,
|
||||
'description': entity.description,
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
|
||||
result = await service.create_api_key('New Key', 'Test description')
|
||||
|
||||
assert insert_params == [
|
||||
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
|
||||
]
|
||||
assert result['key'].startswith('lbk_')
|
||||
assert result['key'] == 'lbk_fixed-token'
|
||||
assert result['name'] == 'New Key'
|
||||
assert result['description'] == 'Test description'
|
||||
|
||||
async def test_create_api_key_without_description(self):
|
||||
"""Creates API key with empty description when not provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_key = Mock(spec=ApiKey)
|
||||
created_key.id = 1
|
||||
created_key.name = 'No Desc Key'
|
||||
created_key.key = 'lbk_no_desc_key'
|
||||
created_key.description = ''
|
||||
|
||||
select_result = Mock()
|
||||
select_result.first = Mock(return_value=created_key)
|
||||
insert_result = Mock()
|
||||
|
||||
async def mock_execute(query):
|
||||
if hasattr(query, 'values'):
|
||||
return insert_result
|
||||
return select_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'No Desc Key',
|
||||
'key': 'lbk_no_desc_key',
|
||||
'description': '',
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_api_key('No Desc Key')
|
||||
|
||||
# Verify
|
||||
assert result['description'] == ''
|
||||
|
||||
|
||||
class TestApiKeyServiceGetApiKey:
|
||||
"""Tests for get_api_key method."""
|
||||
|
||||
async def test_get_api_key_by_id_found(self):
|
||||
"""Returns API key when found by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
key = Mock(spec=ApiKey)
|
||||
key.id = 1
|
||||
key.name = 'Found Key'
|
||||
key.key = 'lbk_found_key'
|
||||
key.description = 'Found'
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=key)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Found Key',
|
||||
'key': 'lbk_found_key',
|
||||
'description': 'Found',
|
||||
}
|
||||
)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(1)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['id'] == 1
|
||||
assert result['name'] == 'Found Key'
|
||||
|
||||
async def test_get_api_key_by_id_not_found(self):
|
||||
"""Returns None when API key not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(999)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_api_key_by_id_zero(self):
|
||||
"""Handles ID=0 (edge case) correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_api_key(0)
|
||||
|
||||
# Verify - should return None (no key with ID 0)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestApiKeyServiceVerifyApiKey:
|
||||
"""Tests for verify_api_key method."""
|
||||
|
||||
async def test_verify_api_key_valid(self):
|
||||
"""Returns True for valid API key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
key = Mock(spec=ApiKey)
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=key)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('lbk_valid_key')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
async def test_verify_api_key_invalid(self):
|
||||
"""Returns False for invalid API key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('lbk_invalid_key')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_verify_api_key_empty_string(self):
|
||||
"""Returns False for empty key string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_verify_api_key_unknown_key(self):
|
||||
"""Returns False when the key is not present in persistence."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.verify_api_key('unknown_key')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestApiKeyServiceDeleteApiKey:
|
||||
"""Tests for delete_api_key method."""
|
||||
|
||||
async def test_delete_api_key_by_id(self):
|
||||
"""Deletes API key by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_api_key(1)
|
||||
|
||||
# Verify - execute_async was called (delete operation)
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_api_key_nonexistent_id(self):
|
||||
"""Delete operation completes even for nonexistent ID (no error raised)."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute - should not raise error
|
||||
await service.delete_api_key(999)
|
||||
|
||||
# Verify - execute_async was called regardless
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestApiKeyServiceUpdateApiKey:
|
||||
"""Tests for update_api_key method."""
|
||||
|
||||
async def test_update_api_key_name_only(self):
|
||||
"""Updates only the name field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, name='Updated Name')
|
||||
|
||||
# Verify - execute_async was called with update
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_description_only(self):
|
||||
"""Updates only the description field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, description='Updated description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_both_fields(self):
|
||||
"""Updates both name and description."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1, name='New Name', description='New description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_api_key_no_fields(self):
|
||||
"""Does nothing when no fields provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ApiKeyService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_api_key(1)
|
||||
|
||||
# Verify - no execute call since no update_data
|
||||
ap.persistence_mgr.execute_async.assert_not_called()
|
||||
662
tests/unit_tests/api/service/test_bot_service.py
Normal file
662
tests/unit_tests/api/service/test_bot_service.py
Normal file
@@ -0,0 +1,662 @@
|
||||
"""
|
||||
Unit tests for BotService.
|
||||
|
||||
Tests bot CRUD operations with mocked persistence and runtime managers.
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/bot.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
|
||||
from langbot.pkg.api.http.service.bot import BotService
|
||||
from langbot.pkg.entity.persistence.bot import Bot
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_bot(
|
||||
bot_uuid: str = None,
|
||||
name: str = 'Test Bot',
|
||||
description: str = 'Test Description',
|
||||
adapter: str = 'telegram',
|
||||
adapter_config: dict = None,
|
||||
enable: bool = True,
|
||||
use_pipeline_uuid: str = None,
|
||||
use_pipeline_name: str = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock Bot entity."""
|
||||
bot = Mock(spec=Bot)
|
||||
bot.uuid = bot_uuid or str(uuid.uuid4())
|
||||
bot.name = name
|
||||
bot.description = description
|
||||
bot.adapter = adapter
|
||||
bot.adapter_config = adapter_config or {'token': 'test_token'}
|
||||
bot.enable = enable
|
||||
bot.use_pipeline_uuid = use_pipeline_uuid
|
||||
bot.use_pipeline_name = use_pipeline_name
|
||||
bot.pipeline_routing_rules = []
|
||||
return bot
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestBotServiceGetBots:
|
||||
"""Tests for get_bots method."""
|
||||
|
||||
async def test_get_bots_empty_list(self):
|
||||
"""Returns empty list when no bots exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_bots_returns_list_with_secrets(self):
|
||||
"""Returns bot list including adapter_config by default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
|
||||
bot2 = _create_mock_bot(bot_uuid='uuid-2', name='Bot 2')
|
||||
|
||||
mock_result = _create_mock_result([bot1, bot2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots(include_secret=True)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Bot 1'
|
||||
assert result[0]['adapter_config'] is not None
|
||||
|
||||
async def test_get_bots_masks_secrets(self):
|
||||
"""Returns bot list without adapter_config when include_secret=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
|
||||
|
||||
mock_result = _create_mock_result([bot1])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity, masked_columns=None: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'adapter': entity.adapter,
|
||||
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bots(include_secret=False)
|
||||
|
||||
# Verify - adapter_config should be masked
|
||||
assert result[0]['adapter_config'] is None
|
||||
|
||||
|
||||
class TestBotServiceGetBot:
|
||||
"""Tests for get_bot method."""
|
||||
|
||||
async def test_get_bot_by_uuid_found(self):
|
||||
"""Returns bot when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
bot = _create_mock_bot(bot_uuid='test-uuid', name='Found Bot')
|
||||
mock_result = _create_mock_result(first_item=bot)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Bot',
|
||||
'adapter': 'telegram',
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bot('test-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'test-uuid'
|
||||
assert result['name'] == 'Found Bot'
|
||||
|
||||
async def test_get_bot_by_uuid_not_found(self):
|
||||
"""Returns None when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_bot('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBotServiceGetRuntimeBotInfo:
|
||||
"""Tests for get_runtime_bot_info method."""
|
||||
|
||||
async def test_get_runtime_bot_info_bot_not_found_raises(self):
|
||||
"""Raises Exception when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Mock get_bot to return None
|
||||
service.get_bot = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.get_runtime_bot_info('nonexistent-uuid')
|
||||
|
||||
async def test_get_runtime_bot_info_returns_webhook_for_wecom(self):
|
||||
"""Returns webhook URL for wecom adapter."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'api': {
|
||||
'webhook_prefix': 'http://127.0.0.1:5300',
|
||||
'extra_webhook_prefix': 'http://extra.example.com',
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'wecom-uuid',
|
||||
'name': 'WeCom Bot',
|
||||
'adapter': 'wecom',
|
||||
'adapter_config': {'token': 'test'},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('wecom-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['adapter_runtime_values']['webhook_url'] == '/bots/wecom-uuid'
|
||||
assert result['adapter_runtime_values']['webhook_full_url'] == 'http://127.0.0.1:5300/bots/wecom-uuid'
|
||||
|
||||
async def test_get_runtime_bot_info_no_webhook_for_telegram(self):
|
||||
"""Returns no webhook URL for non-webhook adapters like telegram."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'api': {}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'telegram-uuid',
|
||||
'name': 'Telegram Bot',
|
||||
'adapter': 'telegram',
|
||||
'adapter_config': {'token': 'test'},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('telegram-uuid')
|
||||
|
||||
# Verify - no webhook for telegram
|
||||
assert result['adapter_runtime_values']['webhook_url'] is None
|
||||
assert result['adapter_runtime_values']['webhook_full_url'] is None
|
||||
|
||||
async def test_get_runtime_bot_info_with_runtime_bot(self):
|
||||
"""Returns bot_account_id when runtime bot exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'api': {}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
# Mock runtime bot with adapter
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.bot_account_id = 'runtime-account-123'
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
bot_data = {
|
||||
'uuid': 'runtime-uuid',
|
||||
'name': 'Runtime Bot',
|
||||
'adapter': 'telegram',
|
||||
'adapter_config': {},
|
||||
}
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value=bot_data)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_bot_info('runtime-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['adapter_runtime_values']['bot_account_id'] == 'runtime-account-123'
|
||||
|
||||
|
||||
class TestBotServiceCreateBot:
|
||||
"""Tests for create_bot method."""
|
||||
|
||||
async def test_create_bot_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_bots limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_bots': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock get_bots to return 2 bots already
|
||||
bot1 = _create_mock_bot(bot_uuid='uuid-1')
|
||||
bot2 = _create_mock_bot(bot_uuid='uuid-2')
|
||||
mock_result = _create_mock_result([bot1, bot2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of bots'):
|
||||
await service.create_bot({'name': 'New Bot'})
|
||||
|
||||
async def test_create_bot_no_limit(self):
|
||||
"""Creates bot without limit check when max_bots=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_bots': -1 # No limit
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=None)
|
||||
# Mock bot query after insert
|
||||
bot_result = Mock()
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
return pipeline_result # First call: check pipeline
|
||||
elif call_count == 3:
|
||||
return Mock() # Insert
|
||||
return bot_result # Get bot
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
bot_uuid = await service.create_bot({'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}})
|
||||
|
||||
# Verify
|
||||
assert bot_uuid is not None
|
||||
assert len(bot_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_bot_sets_default_pipeline(self):
|
||||
"""Sets default pipeline when one exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_bots': -1}}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
# Mock default pipeline
|
||||
mock_pipeline = SimpleNamespace()
|
||||
mock_pipeline.uuid = 'default-pipeline-uuid'
|
||||
mock_pipeline.name = 'Default Pipeline'
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||
|
||||
# Mock bot after insert
|
||||
bot_result = Mock()
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return pipeline_result # Check default pipeline
|
||||
elif call_count == 2:
|
||||
return Mock() # Insert
|
||||
return bot_result # Get bot
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'name': 'New Bot',
|
||||
'use_pipeline_uuid': 'default-pipeline-uuid',
|
||||
'use_pipeline_name': 'Default Pipeline',
|
||||
}
|
||||
)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
bot_data = {'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}}
|
||||
bot_uuid = await service.create_bot(bot_data)
|
||||
|
||||
# Verify - pipeline uuid and name were set
|
||||
assert 'use_pipeline_uuid' in bot_data
|
||||
assert 'use_pipeline_name' in bot_data
|
||||
assert bot_uuid is not None # Verify UUID was returned
|
||||
|
||||
|
||||
class TestBotServiceUpdateBot:
|
||||
"""Tests for update_bot method."""
|
||||
|
||||
async def test_update_bot_removes_uuid_from_data(self):
|
||||
"""Does not persist caller-provided uuid in update payload."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query - not updating pipeline
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
|
||||
|
||||
# Create mock runtime bot
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.enable = False
|
||||
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
# Execute
|
||||
update_data = {'uuid': 'should-be-removed', 'name': 'Updated Name'}
|
||||
await service.update_bot('test-uuid', update_data)
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
|
||||
assert update_params['name'] == 'Updated Name'
|
||||
assert 'should-be-removed' not in update_params.values()
|
||||
|
||||
async def test_update_bot_pipeline_not_found_raises(self):
|
||||
"""Raises Exception when updating with nonexistent pipeline UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock pipeline query returns None
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=pipeline_result)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Pipeline not found'):
|
||||
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'nonexistent-pipeline'})
|
||||
|
||||
async def test_update_bot_sets_pipeline_name(self):
|
||||
"""Sets use_pipeline_name when updating use_pipeline_uuid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
# Mock pipeline query
|
||||
mock_pipeline = SimpleNamespace()
|
||||
mock_pipeline.name = 'Updated Pipeline'
|
||||
pipeline_result = Mock()
|
||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return pipeline_result
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
|
||||
service = BotService(ap)
|
||||
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid'})
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.enable = False
|
||||
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
# Execute
|
||||
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'pipeline-uuid'})
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[1].args[0].compile().params
|
||||
assert update_params['use_pipeline_uuid'] == 'pipeline-uuid'
|
||||
assert update_params['use_pipeline_name'] == 'Updated Pipeline'
|
||||
|
||||
|
||||
class TestBotServiceDeleteBot:
|
||||
"""Tests for delete_bot method."""
|
||||
|
||||
async def test_delete_bot_calls_remove_and_delete(self):
|
||||
"""Calls both platform_mgr.remove_bot and persistence delete."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_bot('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.platform_mgr.remove_bot.assert_called_once_with('test-uuid')
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_bot_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.remove_bot = AsyncMock()
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_bot('nonexistent-uuid')
|
||||
|
||||
# Verify - both called regardless
|
||||
ap.platform_mgr.remove_bot.assert_called_once()
|
||||
|
||||
|
||||
class TestBotServiceListEventLogs:
|
||||
"""Tests for list_event_logs method."""
|
||||
|
||||
async def test_list_event_logs_bot_not_found_raises(self):
|
||||
"""Raises Exception when runtime bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.list_event_logs('nonexistent-uuid', 0, 10)
|
||||
|
||||
async def test_list_event_logs_returns_logs(self):
|
||||
"""Returns logs from runtime bot logger."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
# Mock runtime bot with logger
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.logger = SimpleNamespace()
|
||||
runtime_bot.logger.get_logs = AsyncMock(return_value=(
|
||||
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
|
||||
5
|
||||
))
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute
|
||||
logs, total = await service.list_event_logs('bot-uuid', 0, 10)
|
||||
|
||||
# Verify
|
||||
assert len(logs) == 1
|
||||
assert logs[0] == {'msg': 'log1'}
|
||||
assert total == 5
|
||||
|
||||
|
||||
class TestBotServiceSendMessage:
|
||||
"""Tests for send_message method."""
|
||||
|
||||
async def test_send_message_bot_not_found_raises(self):
|
||||
"""Raises Exception when bot not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='Bot not found'):
|
||||
await service.send_message('nonexistent-uuid', 'group', '123', {'test': 'data'})
|
||||
|
||||
async def test_send_message_invalid_message_chain_raises(self):
|
||||
"""Raises Exception when message_chain_data is invalid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.send_message = AsyncMock()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute & Verify - invalid format should raise
|
||||
with pytest.raises(Exception, match='Invalid message_chain format'):
|
||||
await service.send_message('bot-uuid', 'group', '123', {'invalid': 'format'})
|
||||
|
||||
async def test_send_message_valid_call(self):
|
||||
"""Sends message through adapter when all valid."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.adapter = SimpleNamespace()
|
||||
runtime_bot.adapter.send_message = AsyncMock()
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute with valid message chain format
|
||||
message_chain_data = {
|
||||
'messages': [
|
||||
{'type': 'text', 'data': {'text': 'Hello'}}
|
||||
]
|
||||
}
|
||||
|
||||
# Patch the import location - the module imports inside the function
|
||||
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
|
||||
mock_chain = Mock()
|
||||
MockMessageChain.model_validate = Mock(return_value=mock_chain)
|
||||
await service.send_message('bot-uuid', 'group', '123', message_chain_data)
|
||||
|
||||
# Verify adapter.send_message was called
|
||||
runtime_bot.adapter.send_message.assert_called_once_with('group', '123', mock_chain)
|
||||
397
tests/unit_tests/api/service/test_knowledge_service.py
Normal file
397
tests/unit_tests/api/service/test_knowledge_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Unit tests for API knowledge service.
|
||||
|
||||
Tests cover:
|
||||
- Knowledge base CRUD operations
|
||||
- Capability checking
|
||||
- Knowledge engine discovery
|
||||
- File operations
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def get_knowledge_service_module():
|
||||
"""Lazy import to avoid circular import issues."""
|
||||
return import_module('langbot.pkg.api.http.service.knowledge')
|
||||
|
||||
|
||||
def create_mock_app():
|
||||
"""Create mock Application for testing."""
|
||||
mock_app = Mock()
|
||||
mock_app.logger = Mock()
|
||||
mock_app.rag_mgr = AsyncMock()
|
||||
mock_app.persistence_mgr = AsyncMock()
|
||||
mock_app.persistence_mgr.execute_async = AsyncMock()
|
||||
mock_app.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
mock_app.plugin_connector = AsyncMock()
|
||||
mock_app.plugin_connector.is_enable_plugin = True
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestKnowledgeServiceInit:
|
||||
"""Tests for KnowledgeService initialization."""
|
||||
|
||||
def test_init_stores_app_reference(self):
|
||||
"""Test that __init__ stores Application reference."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
assert service.ap is mock_app
|
||||
|
||||
|
||||
class TestGetKnowledgeBases:
|
||||
"""Tests for get_knowledge_bases method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_all_kb_details(self):
|
||||
"""Test that it returns all knowledge base details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
|
||||
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_bases()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['uuid'] == 'kb1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_list_when_no_kbs(self):
|
||||
"""Test that it returns empty list when no knowledge bases."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[])
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_bases()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetKnowledgeBase:
|
||||
"""Tests for get_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_kb_details_by_uuid(self):
|
||||
"""Test that it returns specific KB details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'KB1'}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_base('kb1')
|
||||
|
||||
assert result['uuid'] == 'kb1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found(self):
|
||||
"""Test that it returns None when KB not found."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_base('nonexistent')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreateKnowledgeBase:
|
||||
"""Tests for create_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_kb_with_required_fields(self):
|
||||
"""Test creating KB with required plugin ID."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_kb = Mock()
|
||||
mock_kb.uuid = 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
kb_data = {
|
||||
'name': 'Test KB',
|
||||
'knowledge_engine_plugin_id': 'author/engine',
|
||||
'description': 'Test description',
|
||||
}
|
||||
|
||||
result = await service.create_knowledge_base(kb_data)
|
||||
|
||||
assert result == 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_missing_plugin_id(self):
|
||||
"""Test that ValueError is raised when plugin ID missing."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.create_knowledge_base({'name': 'Test'})
|
||||
|
||||
assert 'knowledge_engine_plugin_id is required' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_with_default_name(self):
|
||||
"""Test that KB is created with default name if not provided."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_kb = Mock()
|
||||
mock_kb.uuid = 'new_kb_uuid'
|
||||
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
await service.create_knowledge_base({
|
||||
'knowledge_engine_plugin_id': 'author/engine'
|
||||
})
|
||||
|
||||
# Check that default name 'Untitled' was used
|
||||
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
|
||||
assert call_args.kwargs['name'] == 'Untitled'
|
||||
|
||||
|
||||
class TestUpdateKnowledgeBase:
|
||||
"""Tests for update_knowledge_base method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_mutable_fields_only(self):
|
||||
"""Test that only mutable fields are updated."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'Updated'}
|
||||
)
|
||||
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
|
||||
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
# Pass both mutable and immutable fields
|
||||
await service.update_knowledge_base('kb1', {
|
||||
'name': 'New Name',
|
||||
'description': 'New desc',
|
||||
'uuid': 'should_be_filtered', # immutable
|
||||
})
|
||||
|
||||
# Check that only mutable fields were passed to update
|
||||
call_args = mock_app.persistence_mgr.execute_async.call_args
|
||||
assert call_args is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_early_when_no_mutable_fields(self):
|
||||
"""Test that update returns early when no mutable fields provided."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
# Pass only immutable fields
|
||||
await service.update_knowledge_base('kb1', {'uuid': 'should_be_filtered'})
|
||||
|
||||
# No DB update should be called
|
||||
mock_app.persistence_mgr.execute_async.assert_not_called()
|
||||
|
||||
|
||||
class TestCheckDocCapability:
|
||||
"""Tests for _check_doc_capability method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_when_capability_supported(self):
|
||||
"""Test that check passes when doc_ingestion capability exists."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'knowledge_engine': {'capabilities': ['doc_ingestion']}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
await service._check_doc_capability('kb1', 'document upload')
|
||||
|
||||
# No exception raised means success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_kb_not_found(self):
|
||||
"""Test that Exception is raised when KB not found."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await service._check_doc_capability('nonexistent', 'test operation')
|
||||
|
||||
assert 'Knowledge base not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_capability_not_supported(self):
|
||||
"""Test that Exception is raised when doc_ingestion not in capabilities."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'knowledge_engine': {'capabilities': ['other_capability']}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await service._check_doc_capability('kb1', 'document upload')
|
||||
|
||||
assert 'does not support document upload' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestListKnowledgeEngines:
|
||||
"""Tests for list_knowledge_engines method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_engines_from_plugin_connector(self):
|
||||
"""Test that it returns knowledge engines from plugin connector."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||
return_value=[{'id': 'engine1', 'name': 'Engine 1'}]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['id'] == 'engine1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_plugin_disabled(self):
|
||||
"""Test that it returns empty list when plugin disabled."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.is_enable_plugin = False
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_on_exception(self):
|
||||
"""Test that it returns empty list and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||
side_effect=Exception('Connection error')
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
|
||||
assert result == []
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
|
||||
|
||||
class TestListParsers:
|
||||
"""Tests for list_parsers method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_all_parsers(self):
|
||||
"""Test that it returns all parsers when no MIME type filter."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_parsers = AsyncMock(
|
||||
return_value=[
|
||||
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
|
||||
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
|
||||
]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers()
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_mime_type(self):
|
||||
"""Test that it filters parsers by MIME type."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_parsers = AsyncMock(
|
||||
return_value=[
|
||||
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
|
||||
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
|
||||
]
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers(mime_type='application/pdf')
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]['id'] == 'parser2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_plugin_disabled(self):
|
||||
"""Test that it returns empty list when plugin disabled."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.is_enable_plugin = False
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_parsers()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetEngineSchemas:
|
||||
"""Tests for get_engine_creation_schema and get_engine_retrieval_schema."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_creation_schema(self):
|
||||
"""Test that it returns creation schema for engine."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||
return_value={'properties': {'name': {'type': 'string'}}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_creation_schema('author/engine')
|
||||
|
||||
assert 'properties' in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_retrieval_schema(self):
|
||||
"""Test that it returns retrieval schema for engine."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_retrieval_schema = AsyncMock(
|
||||
return_value={'properties': {'top_k': {'type': 'integer'}}}
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_retrieval_schema('author/engine')
|
||||
|
||||
assert 'properties' in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dict_on_exception(self):
|
||||
"""Test that it returns empty dict and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||
side_effect=Exception('Plugin error')
|
||||
)
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_creation_schema('author/engine')
|
||||
|
||||
assert result == {}
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
824
tests/unit_tests/api/service/test_maintenance_service.py
Normal file
824
tests/unit_tests/api/service/test_maintenance_service.py
Normal file
@@ -0,0 +1,824 @@
|
||||
"""
|
||||
Unit tests for MaintenanceService.
|
||||
|
||||
Tests storage maintenance and diagnostics including:
|
||||
- Cleanup expired files
|
||||
- Storage analysis
|
||||
- File counting and sizing
|
||||
- Monitoring counts
|
||||
- Binary storage stats
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/maintenance.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from langbot.pkg.api.http.service.maintenance import MaintenanceService
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_result(scalar_value=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.scalar = Mock(return_value=scalar_value)
|
||||
return result
|
||||
|
||||
|
||||
class TestMaintenanceServiceCleanupExpiredFiles:
|
||||
"""Tests for cleanup_expired_files method."""
|
||||
|
||||
async def test_cleanup_expired_files_default_retention(self):
|
||||
"""Uses default retention days when config not set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
# Create a proper mock object with __class__.__name__
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods - one is async, one is not
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async!
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify - returns counts
|
||||
assert 'uploaded_files' in result
|
||||
assert 'log_files' in result
|
||||
assert result['uploaded_files'] == 0
|
||||
assert result['log_files'] == 0
|
||||
|
||||
async def test_cleanup_expired_files_custom_retention(self):
|
||||
"""Uses custom retention days from config."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'storage': {
|
||||
'cleanup': {
|
||||
'uploaded_file_retention_days': 14,
|
||||
'log_retention_days': 7,
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=2)
|
||||
service._cleanup_expired_log_files = Mock(return_value=3) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify
|
||||
assert result['uploaded_files'] == 2
|
||||
assert result['log_files'] == 3
|
||||
|
||||
async def test_cleanup_expired_files_s3_provider(self):
|
||||
"""Handles S3StorageProvider correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
# Mock S3 provider
|
||||
s3_provider = MagicMock()
|
||||
s3_provider.__class__.__name__ = 'S3StorageProvider'
|
||||
s3_provider.delete = AsyncMock()
|
||||
ap.storage_mgr.storage_provider = s3_provider
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=1)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify
|
||||
assert result['uploaded_files'] == 1
|
||||
assert result['log_files'] == 0
|
||||
|
||||
async def test_cleanup_expired_files_invalid_retention(self):
|
||||
"""Uses default for invalid retention config."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'storage': {
|
||||
'cleanup': {
|
||||
'uploaded_file_retention_days': 'invalid', # Invalid
|
||||
'log_retention_days': 0, # Invalid (less than 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.storage_mgr = SimpleNamespace()
|
||||
|
||||
storage_provider = MagicMock()
|
||||
storage_provider.__class__.__name__ = 'LocalStorageProvider'
|
||||
ap.storage_mgr.storage_provider = storage_provider
|
||||
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock the internal cleanup methods
|
||||
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
|
||||
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
|
||||
|
||||
# Execute
|
||||
result = await service.cleanup_expired_files()
|
||||
|
||||
# Verify - warning logged, defaults used
|
||||
assert ap.logger.warning.called
|
||||
assert 'uploaded_files' in result
|
||||
|
||||
|
||||
class TestMaintenanceServiceGetStorageAnalysis:
|
||||
"""Tests for get_storage_analysis method."""
|
||||
|
||||
async def test_get_storage_analysis_basic(self):
|
||||
"""Returns basic storage analysis."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
|
||||
}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.get_stats = Mock(return_value={'running': 0})
|
||||
|
||||
# Mock monitoring counts
|
||||
count_result = _create_mock_result(scalar_value=10)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock file operations
|
||||
service._path_size = Mock(return_value=1000)
|
||||
service._file_count = Mock(return_value=5)
|
||||
service._monitoring_counts = AsyncMock(return_value={'messages': 10, 'errors': 0})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 5, 'size_bytes': 500})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert 'generated_at' in result
|
||||
assert 'cleanup_policy' in result
|
||||
assert 'sections' in result
|
||||
assert 'database' in result
|
||||
assert 'cleanup_candidates' in result
|
||||
|
||||
async def test_get_storage_analysis_sections(self):
|
||||
"""Returns all storage sections."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'database': {'use': 'postgresql'}}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify - all sections present
|
||||
sections = {s['key'] for s in result['sections']}
|
||||
assert 'database' in sections
|
||||
assert 'logs' in sections
|
||||
assert 'storage' in sections
|
||||
assert 'vector_store' in sections
|
||||
assert 'plugins' in sections
|
||||
assert 'mcp' in sections
|
||||
assert 'temp' in sections
|
||||
|
||||
async def test_get_storage_analysis_postgresql(self):
|
||||
"""Handles PostgreSQL database type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'database': {'use': 'postgresql'}}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': None})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[])
|
||||
service._expired_log_candidates = Mock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert result['database']['type'] == 'postgresql'
|
||||
|
||||
async def test_get_storage_analysis_with_cleanup_candidates(self):
|
||||
"""Returns cleanup candidates in analysis."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
ap.task_mgr = None
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
service._path_size = Mock(return_value=0)
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[
|
||||
{'key': 'old_file', 'size_bytes': 100}
|
||||
])
|
||||
service._expired_log_candidates = Mock(return_value=[
|
||||
{'name': 'old_log', 'size_bytes': 50}
|
||||
])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
|
||||
# Verify
|
||||
assert len(result['cleanup_candidates']['uploaded_files']) == 1
|
||||
assert len(result['cleanup_candidates']['log_files']) == 1
|
||||
|
||||
|
||||
class TestMaintenanceServiceMonitoringCounts:
|
||||
"""Tests for _monitoring_counts method."""
|
||||
|
||||
async def test_monitoring_counts_returns_counts(self):
|
||||
"""Returns counts for all monitoring tables."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=42)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._monitoring_counts()
|
||||
|
||||
# Verify - all table keys present
|
||||
assert 'messages' in result
|
||||
assert 'llm_calls' in result
|
||||
assert 'embedding_calls' in result
|
||||
assert 'errors' in result
|
||||
assert 'sessions' in result
|
||||
assert 'feedback' in result
|
||||
|
||||
async def test_monitoring_counts_zero_results(self):
|
||||
"""Returns zero counts when tables empty."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=0)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._monitoring_counts()
|
||||
|
||||
# Verify - all zero
|
||||
assert all(v == 0 for v in result.values())
|
||||
|
||||
|
||||
class TestMaintenanceServiceBinaryStorageStats:
|
||||
"""Tests for _binary_storage_stats method."""
|
||||
|
||||
async def test_binary_storage_stats_returns_stats(self):
|
||||
"""Returns count and size for binary storage."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
# Mock count result
|
||||
count_result = _create_mock_result(scalar_value=10)
|
||||
# Mock size result
|
||||
size_result = _create_mock_result(scalar_value=5000)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return count_result
|
||||
return size_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._binary_storage_stats()
|
||||
|
||||
# Verify
|
||||
assert result['count'] == 10
|
||||
assert result['size_bytes'] == 5000
|
||||
|
||||
async def test_binary_storage_stats_size_error(self):
|
||||
"""Handles error when calculating size."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
count_result = _create_mock_result(scalar_value=5)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return count_result
|
||||
raise Exception('Size calculation error')
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._binary_storage_stats()
|
||||
|
||||
# Verify - warning logged, size_bytes None or 0
|
||||
assert ap.logger.warning.called
|
||||
assert result['count'] == 5
|
||||
|
||||
|
||||
class TestMaintenanceServicePathSize:
|
||||
"""Tests for _path_size method."""
|
||||
|
||||
def test_path_size_nonexistent_path(self):
|
||||
"""Returns 0 for nonexistent path."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._path_size(Path('/nonexistent/path'))
|
||||
|
||||
# Verify
|
||||
assert result == 0
|
||||
|
||||
def test_path_size_single_file(self):
|
||||
"""Returns size for single file."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock file
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=True):
|
||||
with patch.object(Path, 'stat', return_value=mock_stat):
|
||||
result = service._path_size(Path('test.txt'))
|
||||
|
||||
# Verify
|
||||
assert result == 100
|
||||
|
||||
def test_path_size_directory(self):
|
||||
"""Returns total size for directory."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock os.walk
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=False):
|
||||
with patch('os.walk') as mock_walk:
|
||||
mock_walk.return_value = [
|
||||
('/test_dir', [], ['file1.txt', 'file2.txt']),
|
||||
]
|
||||
|
||||
# Mock file stat
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 50
|
||||
|
||||
with patch.object(Path, 'stat', return_value=mock_stat):
|
||||
result = service._path_size(Path('/test_dir'))
|
||||
|
||||
# Verify - 2 files * 50 bytes
|
||||
assert result == 100
|
||||
|
||||
|
||||
class TestMaintenanceServiceFileCount:
|
||||
"""Tests for _file_count method."""
|
||||
|
||||
def test_file_count_nonexistent_path(self):
|
||||
"""Returns 0 for nonexistent path."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._file_count(Path('/nonexistent/path'))
|
||||
|
||||
# Verify
|
||||
assert result == 0
|
||||
|
||||
def test_file_count_single_file(self):
|
||||
"""Returns 1 for single file."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=True):
|
||||
result = service._file_count(Path('test.txt'))
|
||||
|
||||
# Verify
|
||||
assert result == 1
|
||||
|
||||
def test_file_count_directory(self):
|
||||
"""Returns file count for directory."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'is_file', return_value=False):
|
||||
with patch('os.walk') as mock_walk:
|
||||
mock_walk.return_value = [
|
||||
('/test_dir', [], ['file1.txt', 'file2.txt', 'file3.txt']),
|
||||
]
|
||||
result = service._file_count(Path('/test_dir'))
|
||||
|
||||
# Verify
|
||||
assert result == 3
|
||||
|
||||
|
||||
class TestMaintenanceServicePositiveInt:
|
||||
"""Tests for _positive_int helper method."""
|
||||
|
||||
def test_positive_int_valid_value(self):
|
||||
"""Returns valid positive integer."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(7, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 7
|
||||
assert not ap.logger.warning.called
|
||||
|
||||
def test_positive_int_invalid_string(self):
|
||||
"""Returns default for invalid string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int('invalid', 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_invalid_none(self):
|
||||
"""Returns default for None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(None, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_negative_value(self):
|
||||
"""Returns default for negative value."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(-1, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
def test_positive_int_zero_value(self):
|
||||
"""Returns default for zero value."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service._positive_int(0, 5, 'test_param')
|
||||
|
||||
# Verify
|
||||
assert result == 5
|
||||
assert ap.logger.warning.called
|
||||
|
||||
|
||||
class TestMaintenanceServiceIsUploadedFileKey:
|
||||
"""Tests for _is_uploaded_file_key helper method."""
|
||||
|
||||
def test_is_uploaded_file_key_valid(self):
|
||||
"""Returns True for valid upload file key."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - simple filename without path
|
||||
result = service._is_uploaded_file_key('uploaded_file.txt')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
def test_is_uploaded_file_key_with_path(self):
|
||||
"""Returns False for key with path separator."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - key with path
|
||||
result = service._is_uploaded_file_key('path/to/file.txt')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
def test_is_uploaded_file_key_plugin_config(self):
|
||||
"""Returns False for plugin config prefix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Execute - plugin config file
|
||||
result = service._is_uploaded_file_key('plugin_config_some_plugin.json')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMaintenanceServiceExpiredLogCandidates:
|
||||
"""Tests for _expired_log_candidates method."""
|
||||
|
||||
def test_expired_log_candidates_nonexistent_dir(self):
|
||||
"""Returns empty list when logs dir not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=False):
|
||||
result = service._expired_log_candidates(3)
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
def test_expired_log_candidates_matches_pattern(self):
|
||||
"""Matches log file pattern correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
# Mock directory with log files
|
||||
old_date = datetime.date.today() - datetime.timedelta(days=10)
|
||||
old_log_name = f'langbot-{old_date.isoformat()}.log'
|
||||
recent_log_name = f'langbot-{datetime.date.today().isoformat()}.log'
|
||||
|
||||
mock_entry_old = Mock(spec=Path)
|
||||
mock_entry_old.is_file = Mock(return_value=True)
|
||||
mock_entry_old.name = old_log_name
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 1000
|
||||
mock_entry_old.stat = Mock(return_value=mock_stat)
|
||||
|
||||
mock_entry_recent = Mock(spec=Path)
|
||||
mock_entry_recent.is_file = Mock(return_value=True)
|
||||
mock_entry_recent.name = recent_log_name
|
||||
mock_stat2 = Mock()
|
||||
mock_stat2.st_size = 500
|
||||
mock_entry_recent.stat = Mock(return_value=mock_stat2)
|
||||
|
||||
# Non-log file
|
||||
mock_entry_other = Mock(spec=Path)
|
||||
mock_entry_other.is_file = Mock(return_value=True)
|
||||
mock_entry_other.name = 'other_file.txt'
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry_old, mock_entry_recent, mock_entry_other]
|
||||
result = service._expired_log_candidates(3)
|
||||
|
||||
# Verify - only old log included
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == old_log_name
|
||||
|
||||
def test_expired_log_candidates_includes_path(self):
|
||||
"""Includes path when include_paths=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
old_date = datetime.date.today() - datetime.timedelta(days=10)
|
||||
old_log_name = f'langbot-{old_date.isoformat()}.log'
|
||||
|
||||
mock_entry = Mock(spec=Path)
|
||||
mock_entry.is_file = Mock(return_value=True)
|
||||
mock_entry.name = old_log_name
|
||||
mock_entry.__str__ = Mock(return_value='/data/logs/' + old_log_name)
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 1000
|
||||
mock_entry.stat = Mock(return_value=mock_stat)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry]
|
||||
result = service._expired_log_candidates(3, include_paths=True)
|
||||
|
||||
# Verify - path included
|
||||
assert 'path' in result[0]
|
||||
|
||||
|
||||
class TestMaintenanceServiceExpiredLocalUploadCandidates:
|
||||
"""Tests for _expired_local_upload_candidates method."""
|
||||
|
||||
def test_expired_local_upload_candidates_nonexistent_dir(self):
|
||||
"""Returns empty list when storage dir not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=False):
|
||||
result = service._expired_local_upload_candidates(7)
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
def test_expired_local_upload_candidates_filters_uploaded(self):
|
||||
"""Only returns uploaded files matching pattern."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
# Mock _is_uploaded_file_key
|
||||
service._is_uploaded_file_key = Mock(side_effect=lambda key: 'plugin_config_' not in key and '/' not in key)
|
||||
|
||||
# Create mock files - one valid, one plugin config
|
||||
mock_entry_valid = Mock(spec=Path)
|
||||
mock_entry_valid.is_file = Mock(return_value=True)
|
||||
mock_entry_valid.name = 'valid_upload.txt'
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
mock_stat.st_mtime = 0 # Very old
|
||||
mock_entry_valid.stat = Mock(return_value=mock_stat)
|
||||
|
||||
mock_entry_plugin = Mock(spec=Path)
|
||||
mock_entry_plugin.is_file = Mock(return_value=True)
|
||||
mock_entry_plugin.name = 'plugin_config_test.json'
|
||||
mock_stat2 = Mock()
|
||||
mock_stat2.st_size = 200
|
||||
mock_stat2.st_mtime = 0
|
||||
mock_entry_plugin.stat = Mock(return_value=mock_stat2)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry_valid, mock_entry_plugin]
|
||||
result = service._expired_local_upload_candidates(7)
|
||||
|
||||
# Verify - only valid upload included
|
||||
assert len(result) == 1
|
||||
assert result[0]['key'] == 'valid_upload.txt'
|
||||
|
||||
def test_expired_local_upload_candidates_includes_path(self):
|
||||
"""Includes path when include_paths=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
|
||||
service = MaintenanceService(ap)
|
||||
service._is_uploaded_file_key = Mock(return_value=True)
|
||||
|
||||
mock_entry = Mock(spec=Path)
|
||||
mock_entry.is_file = Mock(return_value=True)
|
||||
mock_entry.name = 'old_file.txt'
|
||||
mock_entry.__str__ = Mock(return_value='/data/storage/old_file.txt')
|
||||
mock_stat = Mock()
|
||||
mock_stat.st_size = 100
|
||||
mock_stat.st_mtime = 0
|
||||
mock_entry.stat = Mock(return_value=mock_stat)
|
||||
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
with patch.object(Path, 'iterdir') as mock_iterdir:
|
||||
mock_iterdir.return_value = [mock_entry]
|
||||
result = service._expired_local_upload_candidates(7, include_paths=True)
|
||||
|
||||
# Verify - path included
|
||||
assert 'path' in result[0]
|
||||
648
tests/unit_tests/api/service/test_mcp_service.py
Normal file
648
tests/unit_tests/api/service/test_mcp_service.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
Unit tests for MCPService.
|
||||
|
||||
Tests MCP server CRUD operations including:
|
||||
- MCP server listing with runtime info
|
||||
- MCP server creation with limitations
|
||||
- MCP server update with enable/disable
|
||||
- MCP server deletion
|
||||
- MCP server connection testing
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/mcp.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
|
||||
from langbot.pkg.api.http.service.mcp import MCPService
|
||||
from langbot.pkg.entity.persistence.mcp import MCPServer
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_mcp_server(
|
||||
server_uuid: str = None,
|
||||
name: str = 'Test MCP Server',
|
||||
enable: bool = True,
|
||||
mode: str = 'stdio',
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock MCPServer entity."""
|
||||
server = Mock(spec=MCPServer)
|
||||
server.uuid = server_uuid or str(uuid.uuid4())
|
||||
server.name = name
|
||||
server.enable = enable
|
||||
server.mode = mode
|
||||
server.extra_args = extra_args or {}
|
||||
return server
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestMCPServiceGetRuntimeInfo:
|
||||
"""Tests for get_runtime_info method."""
|
||||
|
||||
async def test_get_runtime_info_session_exists(self):
|
||||
"""Returns runtime info when session exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
mock_session = SimpleNamespace()
|
||||
mock_session.get_runtime_info_dict = Mock(return_value={'status': 'running', 'tools': 5})
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_info('test-server')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['status'] == 'running'
|
||||
|
||||
async def test_get_runtime_info_session_not_exists(self):
|
||||
"""Returns None when session not exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_runtime_info('nonexistent-server')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMCPServiceGetMCPServers:
|
||||
"""Tests for get_mcp_servers method."""
|
||||
|
||||
async def test_get_mcp_servers_empty_list(self):
|
||||
"""Returns empty list when no MCP servers exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_mcp_servers_returns_serialized_list(self):
|
||||
"""Returns serialized list of MCP servers."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
|
||||
server2 = _create_mock_mcp_server(server_uuid='uuid-2', name='Server 2')
|
||||
|
||||
mock_result = _create_mock_result([server1, server2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'enable': entity.enable,
|
||||
'mode': entity.mode,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Server 1'
|
||||
assert result[1]['name'] == 'Server 2'
|
||||
|
||||
async def test_get_mcp_servers_with_runtime_info(self):
|
||||
"""Returns MCP servers with runtime info when requested."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
|
||||
|
||||
mock_result = _create_mock_result([server1])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
service.get_runtime_info = AsyncMock(return_value={'status': 'connected'})
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_servers(contain_runtime_info=True)
|
||||
|
||||
# Verify - runtime info included
|
||||
assert result[0]['runtime_info'] == {'status': 'connected'}
|
||||
|
||||
|
||||
class TestMCPServiceCreateMCPServer:
|
||||
"""Tests for create_mcp_server method."""
|
||||
|
||||
async def test_create_mcp_server_max_extensions_reached_raises(self):
|
||||
"""Raises ValueError when max_extensions limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_extensions': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.plugin_connector = SimpleNamespace()
|
||||
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
|
||||
|
||||
# Mock get_mcp_servers to return 0 servers (2 plugins already)
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute & Verify - 2 plugins + new server would exceed limit
|
||||
with pytest.raises(ValueError, match='Maximum number of extensions'):
|
||||
await service.create_mcp_server({'name': 'New Server'})
|
||||
|
||||
async def test_create_mcp_server_no_limit(self):
|
||||
"""Creates MCP server without limit when max_extensions=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_extensions': -1 # No limit
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.tool_mgr = None
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
server_uuid = await service.create_mcp_server({'name': 'New Server'})
|
||||
|
||||
# Verify
|
||||
assert server_uuid is not None
|
||||
assert len(server_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_mcp_server_loads_server(self):
|
||||
"""Loads server into tool_mgr when enabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
# Create mock server entity
|
||||
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result([]) # Empty list for limit check
|
||||
elif call_count == 2:
|
||||
return Mock() # Insert
|
||||
return _create_mock_result(first_item=server_entity) # Select created
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.create_mcp_server({'name': 'New Server', 'enable': True})
|
||||
|
||||
# Verify - host_mcp_server was called
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_create_mcp_server_disabled_no_load(self):
|
||||
"""Does not load server when disabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
|
||||
ap.tool_mgr = None
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute with enable=False
|
||||
server_uuid = await service.create_mcp_server({'name': 'New Server', 'enable': False})
|
||||
|
||||
# Verify - no tool_mgr load attempt
|
||||
assert server_uuid is not None
|
||||
|
||||
|
||||
class TestMCPServiceGetMCPServerByName:
|
||||
"""Tests for get_mcp_server_by_name method."""
|
||||
|
||||
async def test_get_mcp_server_by_name_found(self):
|
||||
"""Returns MCP server when found by name."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
server = _create_mock_mcp_server(name='Found Server')
|
||||
mock_result = _create_mock_result(first_item=server)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Server',
|
||||
'runtime_info': None,
|
||||
}
|
||||
)
|
||||
ap.tool_mgr = None
|
||||
|
||||
service = MCPService(ap)
|
||||
service.get_runtime_info = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_server_by_name('Found Server')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['name'] == 'Found Server'
|
||||
|
||||
async def test_get_mcp_server_by_name_not_found(self):
|
||||
"""Returns None when MCP server not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_mcp_server_by_name('Nonexistent Server')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMCPServiceUpdateMCPServer:
|
||||
"""Tests for update_mcp_server method."""
|
||||
|
||||
async def test_update_mcp_server_disable_enabled_server(self):
|
||||
"""Removes server when disabling previously enabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
return Mock() # Update
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - disable server
|
||||
await service.update_mcp_server('test-uuid', {'enable': False})
|
||||
|
||||
# Verify - server was removed
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_enable_disabled_server(self):
|
||||
"""Loads server when enabling previously disabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {}
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=False)
|
||||
|
||||
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
elif call_count == 2:
|
||||
return Mock() # Update
|
||||
return _create_mock_result(first_item=updated_server) # Select updated
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - enable server
|
||||
await service.update_mcp_server('test-uuid', {'enable': True})
|
||||
|
||||
# Verify - server was loaded
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_update_enabled_server(self):
|
||||
"""Removes and reloads server when updating enabled server."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
# Mock for: first select -> update -> second select (for updated server)
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# All selects return the server
|
||||
return _create_mock_result(first_item=old_server)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - update enabled server (keep enabled, update extra_args)
|
||||
await service.update_mcp_server('test-uuid', {'enable': True, 'extra_args': {'new': 'args'}})
|
||||
|
||||
# Verify - remove and reload
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Old Server')
|
||||
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
|
||||
|
||||
async def test_update_mcp_server_no_tool_mgr(self):
|
||||
"""Updates persistence without tool_mgr operations."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Set mcp_tool_loader to None, not tool_mgr itself
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = None
|
||||
|
||||
old_server = _create_mock_mcp_server(name='Server', enable=True)
|
||||
|
||||
# Mock execute for select and update
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=old_server)
|
||||
return Mock() # Update
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.update_mcp_server('test-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify - persistence was called
|
||||
assert ap.persistence_mgr.execute_async.call_count >= 2
|
||||
|
||||
|
||||
class TestMCPServiceDeleteMCPServer:
|
||||
"""Tests for delete_mcp_server method."""
|
||||
|
||||
async def test_delete_mcp_server_calls_remove_and_delete(self):
|
||||
"""Calls both persistence delete and tool_mgr remove."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {'Server to Delete': Mock()}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
server = _create_mock_mcp_server(name='Server to Delete')
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=server)
|
||||
return Mock() # Delete
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_mcp_server('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Server to Delete')
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_delete_mcp_server_not_in_sessions(self):
|
||||
"""Does not attempt remove if server not in sessions."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {} # Server not in sessions
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
server = _create_mock_mcp_server(name='Not in Sessions')
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=server)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_mcp_server('test-uuid')
|
||||
|
||||
# Verify - remove not called (server not in sessions)
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_not_called()
|
||||
|
||||
async def test_delete_mcp_server_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.sessions = {}
|
||||
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
|
||||
|
||||
# No server found
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=None)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_mcp_server('nonexistent-uuid')
|
||||
|
||||
# Verify - delete was called regardless
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
|
||||
class TestMCPServiceTestMCPServer:
|
||||
"""Tests for test_mcp_server method."""
|
||||
|
||||
async def test_test_mcp_server_existing_server(self):
|
||||
"""Tests existing MCP server connection."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
from langbot.pkg.provider.tools.loaders.mcp import MCPSessionStatus
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.status = MCPSessionStatus.ERROR
|
||||
mock_session.start = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=123)
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute
|
||||
task_id = await service.test_mcp_server('existing-server', {})
|
||||
|
||||
# Verify - returns task ID
|
||||
assert task_id == 123
|
||||
|
||||
async def test_test_mcp_server_not_found_raises(self):
|
||||
"""Raises ValueError when server not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Server not found'):
|
||||
await service.test_mcp_server('nonexistent-server', {})
|
||||
|
||||
async def test_test_mcp_server_new_server(self):
|
||||
"""Tests new MCP server with underscore name."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.tool_mgr = SimpleNamespace()
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.start = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=456)
|
||||
)
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
# Execute with '_' name (new server)
|
||||
task_id = await service.test_mcp_server('_', {'name': 'New Server'})
|
||||
|
||||
# Verify - load_mcp_server called
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
||||
assert task_id == 456
|
||||
964
tests/unit_tests/api/service/test_model_service.py
Normal file
964
tests/unit_tests/api/service/test_model_service.py
Normal file
@@ -0,0 +1,964 @@
|
||||
"""
|
||||
Unit tests for LLMModelsService, EmbeddingModelsService, and RerankModelsService.
|
||||
|
||||
Tests model management operations including:
|
||||
- Model CRUD operations
|
||||
- Model with provider info
|
||||
- Provider auto-creation on model create/update
|
||||
- Runtime model loading/unloading
|
||||
- Model deletion
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/model.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.model import (
|
||||
LLMModelsService,
|
||||
EmbeddingModelsService,
|
||||
RerankModelsService,
|
||||
_parse_provider_api_keys,
|
||||
_runtime_model_data,
|
||||
)
|
||||
from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, RerankModel, ModelProvider
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_llm_model(
|
||||
model_uuid: str = 'llm-uuid',
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
abilities: list = None,
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
model = Mock(spec=LLMModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.abilities = abilities or []
|
||||
model.extra_args = extra_args or {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_embedding_model(
|
||||
model_uuid: str = 'embedding-uuid',
|
||||
name: str = 'Test Embedding',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock EmbeddingModel entity."""
|
||||
model = Mock(spec=EmbeddingModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.extra_args = {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_rerank_model(
|
||||
model_uuid: str = 'rerank-uuid',
|
||||
name: str = 'Test Rerank',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock RerankModel entity."""
|
||||
model = Mock(spec=RerankModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.extra_args = {}
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_provider(
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
name: str = 'Test Provider',
|
||||
api_keys: list = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock ModelProvider entity."""
|
||||
provider = Mock(spec=ModelProvider)
|
||||
provider.uuid = provider_uuid
|
||||
provider.name = name
|
||||
provider.requester = 'openai'
|
||||
provider.base_url = 'https://api.openai.com'
|
||||
provider.api_keys = api_keys or ['key']
|
||||
return provider
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestParseProviderApiKeys:
|
||||
"""Tests for _parse_provider_api_keys helper function."""
|
||||
|
||||
def test_parse_valid_json_string(self):
|
||||
"""Parses valid JSON string to list."""
|
||||
provider_dict = {'api_keys': '["key1", "key2"]'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == ['key1', 'key2']
|
||||
|
||||
def test_parse_invalid_json_returns_empty(self):
|
||||
"""Returns empty list for invalid JSON."""
|
||||
provider_dict = {'api_keys': 'invalid json'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == []
|
||||
|
||||
def test_parse_already_list(self):
|
||||
"""Returns unchanged if already a list."""
|
||||
provider_dict = {'api_keys': ['key1', 'key2']}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert result['api_keys'] == ['key1', 'key2']
|
||||
|
||||
def test_parse_missing_key(self):
|
||||
"""Handles missing api_keys key."""
|
||||
provider_dict = {'name': 'Provider'}
|
||||
result = _parse_provider_api_keys(provider_dict)
|
||||
assert 'api_keys' not in result
|
||||
|
||||
|
||||
class TestRuntimeModelData:
|
||||
"""Tests for _runtime_model_data helper function."""
|
||||
|
||||
def test_runtime_data_preserves_uuid(self):
|
||||
"""Preserves UUID in runtime data."""
|
||||
update_payload = {'name': 'Updated', 'provider_uuid': 'provider'}
|
||||
result = _runtime_model_data('model-uuid', update_payload)
|
||||
assert result['uuid'] == 'model-uuid'
|
||||
assert result['name'] == 'Updated'
|
||||
|
||||
def test_runtime_data_copies_all_fields(self):
|
||||
"""Copies all fields from payload."""
|
||||
update_payload = {
|
||||
'name': 'Model',
|
||||
'provider_uuid': 'provider',
|
||||
'abilities': ['vision'],
|
||||
'extra_args': {'temp': 0.7},
|
||||
}
|
||||
result = _runtime_model_data('uuid', update_payload)
|
||||
assert result['abilities'] == ['vision']
|
||||
assert result['extra_args'] == {'temp': 0.7}
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModels:
|
||||
"""Tests for LLMModelsService.get_llm_models method."""
|
||||
|
||||
async def test_get_llm_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
mock_provider_result = _create_mock_result([])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
return mock_result if call_count == 0 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_llm_models_with_provider_info(self):
|
||||
"""Returns models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == 'Test LLM'
|
||||
|
||||
async def test_get_llm_models_hide_secret_keys(self):
|
||||
"""Hides secret API keys when include_secret=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
provider = _create_mock_provider(api_keys=['secret-key-1', 'secret-key-2'])
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models(include_secret=False)
|
||||
|
||||
# Verify - keys should be masked
|
||||
assert result[0]['provider']['api_keys'] == ['***', '***']
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModel:
|
||||
"""Tests for LLMModelsService.get_llm_model method."""
|
||||
|
||||
async def test_get_llm_model_found(self):
|
||||
"""Returns model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Test LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
|
||||
}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_model('found-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
|
||||
async def test_get_llm_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_model('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLLMModelsServiceGetLLMModelsByProvider:
|
||||
"""Tests for LLMModelsService.get_llm_models_by_provider method."""
|
||||
|
||||
async def test_get_models_by_provider_uuid(self):
|
||||
"""Returns models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_llm_model(model_uuid='model-1', provider_uuid='target-provider')
|
||||
model2 = _create_mock_llm_model(model_uuid='model-2', provider_uuid='target-provider')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'model-1', 'name': 'Model 1'}
|
||||
)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_llm_models_by_provider('target-provider')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestLLMModelsServiceCreateLLMModel:
|
||||
"""Tests for LLMModelsService.create_llm_model method."""
|
||||
|
||||
async def test_create_llm_model_generates_uuid(self):
|
||||
"""Creates LLM model with generated UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'name': 'New LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
assert len(model_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_llm_model_preserve_uuid(self):
|
||||
"""Creates LLM model preserving provided UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'uuid': 'preserved-uuid',
|
||||
'name': 'Preserved UUID Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}, preserve_uuid=True)
|
||||
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty - no provider
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_llm_model({
|
||||
'name': 'No Provider Model',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
async def test_create_llm_model_with_provider_data(self):
|
||||
"""Creates provider when provider data provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.find_or_create_provider = AsyncMock(return_value='new-provider-uuid')
|
||||
ap.pipeline_service = SimpleNamespace()
|
||||
ap.pipeline_service.update_pipeline = AsyncMock()
|
||||
|
||||
# Create runtime provider
|
||||
runtime_provider = Mock()
|
||||
ap.model_mgr.provider_dict['new-provider-uuid'] = runtime_provider
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute - with provider data (no UUID)
|
||||
result_uuid = await service.create_llm_model({
|
||||
'name': 'Model with New Provider',
|
||||
'provider': {
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
},
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify - provider_service was called and UUID generated
|
||||
ap.provider_service.find_or_create_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestLLMModelsServiceUpdateLLMModel:
|
||||
"""Tests for LLMModelsService.update_llm_model method."""
|
||||
|
||||
async def test_update_llm_model_removes_uuid_from_data(self):
|
||||
"""Removes uuid from update data before persisting."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_llm_model('existing-uuid', {
|
||||
'uuid': 'should-be-removed',
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
})
|
||||
|
||||
# Verify - remove and load called
|
||||
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
|
||||
|
||||
async def test_update_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found after update."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.update_llm_model('model-uuid', {
|
||||
'name': 'Update',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
|
||||
|
||||
class TestLLMModelsServiceDeleteLLMModel:
|
||||
"""Tests for LLMModelsService.delete_llm_model method."""
|
||||
|
||||
async def test_delete_llm_model_success(self):
|
||||
"""Deletes LLM model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_llm_model('delete-uuid')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
ap.model_mgr.remove_llm_model.assert_called_once_with('delete-uuid')
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModels:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_models method."""
|
||||
|
||||
async def test_get_embedding_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_embedding_models_with_provider(self):
|
||||
"""Returns embedding models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_embedding_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'api_keys': getattr(entity, 'api_keys', ['key']),
|
||||
}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_model method."""
|
||||
|
||||
async def test_get_embedding_model_found(self):
|
||||
"""Returns embedding model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_embedding_model(model_uuid='found-embedding')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-embedding',
|
||||
'name': 'Found Embedding',
|
||||
'provider': {'uuid': 'provider-uuid'},
|
||||
}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_model('found-embedding')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
|
||||
async def test_get_embedding_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_model('nonexistent-embedding')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.create_embedding_model method."""
|
||||
|
||||
async def test_create_embedding_model_success(self):
|
||||
"""Creates embedding model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.embedding_models = []
|
||||
ap.model_mgr.load_embedding_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_embedding_model({
|
||||
'name': 'New Embedding',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
assert len(model_uuid) == 36
|
||||
|
||||
async def test_create_embedding_model_provider_not_found_raises(self):
|
||||
"""Raises Exception when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {} # Empty
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_embedding_model({
|
||||
'name': 'No Provider Embedding',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
|
||||
"""Tests for EmbeddingModelsService.delete_embedding_model method."""
|
||||
|
||||
async def test_delete_embedding_model_success(self):
|
||||
"""Deletes embedding model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_embedding_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_embedding_model('delete-embedding-uuid')
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.remove_embedding_model.assert_called_once()
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModels:
|
||||
"""Tests for RerankModelsService.get_rerank_models method."""
|
||||
|
||||
async def test_get_rerank_models_empty_list(self):
|
||||
"""Returns empty list when no models exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_rerank_models_with_provider(self):
|
||||
"""Returns rerank models with provider info."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_rerank_model()
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'api_keys': getattr(entity, 'api_keys', ['key']),
|
||||
}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModel:
|
||||
"""Tests for RerankModelsService.get_rerank_model method."""
|
||||
|
||||
async def test_get_rerank_model_found(self):
|
||||
"""Returns rerank model when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_rerank_model(model_uuid='found-rerank')
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_model_result if call_count == 1 else mock_provider_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-rerank',
|
||||
'name': 'Found Rerank',
|
||||
'provider': {'uuid': 'provider-uuid'},
|
||||
}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_model('found-rerank')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
|
||||
async def test_get_rerank_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_model('nonexistent-rerank')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRerankModelsServiceCreateRerankModel:
|
||||
"""Tests for RerankModelsService.create_rerank_model method."""
|
||||
|
||||
async def test_create_rerank_model_success(self):
|
||||
"""Creates rerank model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.rerank_models = []
|
||||
ap.model_mgr.load_rerank_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_rerank_model({
|
||||
'name': 'New Rerank',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
|
||||
async def test_create_rerank_model_provider_not_found_raises(self):
|
||||
"""Raises Exception when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_rerank_model({
|
||||
'name': 'No Provider Rerank',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
|
||||
|
||||
class TestRerankModelsServiceDeleteRerankModel:
|
||||
"""Tests for RerankModelsService.delete_rerank_model method."""
|
||||
|
||||
async def test_delete_rerank_model_success(self):
|
||||
"""Deletes rerank model successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_rerank_model = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_rerank_model('delete-rerank-uuid')
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.remove_rerank_model.assert_called_once()
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
|
||||
"""Tests for EmbeddingModelsService.get_embedding_models_by_provider method."""
|
||||
|
||||
async def test_get_embedding_models_by_provider_uuid(self):
|
||||
"""Returns embedding models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_embedding_model(model_uuid='emb-1', provider_uuid='provider-uuid')
|
||||
model2 = _create_mock_embedding_model(model_uuid='emb-2', provider_uuid='provider-uuid')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
|
||||
)
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_embedding_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
"""Tests for RerankModelsService.get_rerank_models_by_provider method."""
|
||||
|
||||
async def test_get_rerank_models_by_provider_uuid(self):
|
||||
"""Returns rerank models for specific provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model1 = _create_mock_rerank_model(model_uuid='rerank-1', provider_uuid='provider-uuid')
|
||||
model2 = _create_mock_rerank_model(model_uuid='rerank-2', provider_uuid='provider-uuid')
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
|
||||
)
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_rerank_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
831
tests/unit_tests/api/service/test_pipeline_service.py
Normal file
831
tests/unit_tests/api/service/test_pipeline_service.py
Normal file
@@ -0,0 +1,831 @@
|
||||
"""
|
||||
Unit tests for PipelineService.
|
||||
|
||||
Tests pipeline CRUD operations including:
|
||||
- Pipeline listing with sorting
|
||||
- Pipeline creation with default config
|
||||
- Pipeline update with bot sync
|
||||
- Pipeline copy functionality
|
||||
- Extensions preferences management
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/pipeline.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, mock_open
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from langbot.pkg.api.http.service.pipeline import PipelineService, default_stage_order
|
||||
from langbot.pkg.entity.persistence.pipeline import LegacyPipeline
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_pipeline(
|
||||
pipeline_uuid: str = None,
|
||||
name: str = 'Test Pipeline',
|
||||
description: str = 'Test Description',
|
||||
is_default: bool = False,
|
||||
stages: list = None,
|
||||
config: dict = None,
|
||||
extensions_preferences: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LegacyPipeline entity."""
|
||||
pipeline = Mock(spec=LegacyPipeline)
|
||||
pipeline.uuid = pipeline_uuid or str(uuid.uuid4())
|
||||
pipeline.name = name
|
||||
pipeline.description = description
|
||||
pipeline.emoji = '⚙️'
|
||||
pipeline.is_default = is_default
|
||||
pipeline.for_version = '1.0.0'
|
||||
pipeline.stages = stages or default_stage_order.copy()
|
||||
pipeline.config = config or {}
|
||||
pipeline.extensions_preferences = extensions_preferences or {
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': [],
|
||||
}
|
||||
return pipeline
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipelineMetadata:
|
||||
"""Tests for get_pipeline_metadata method."""
|
||||
|
||||
async def test_get_pipeline_metadata_returns_list(self):
|
||||
"""Returns list of pipeline metadata configs."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.pipeline_config_meta_trigger = {'trigger': {}}
|
||||
ap.pipeline_config_meta_safety = {'safety': {}}
|
||||
ap.pipeline_config_meta_ai = {'ai': {}}
|
||||
ap.pipeline_config_meta_output = {'output': {}}
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline_metadata()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 4
|
||||
assert 'trigger' in result[0]
|
||||
assert 'safety' in result[1]
|
||||
assert 'ai' in result[2]
|
||||
assert 'output' in result[3]
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipelines:
|
||||
"""Tests for get_pipelines method."""
|
||||
|
||||
async def test_get_pipelines_empty_list(self):
|
||||
"""Returns empty list when no pipelines exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipelines()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_pipelines_returns_sorted_by_created_at_desc(self):
|
||||
"""Returns pipelines sorted by created_at descending by default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
pipeline1 = _create_mock_pipeline(pipeline_uuid='uuid-1', name='Pipeline 1')
|
||||
pipeline2 = _create_mock_pipeline(pipeline_uuid='uuid-2', name='Pipeline 2')
|
||||
|
||||
mock_result = _create_mock_result([pipeline1, pipeline2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipelines()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_get_pipelines_sort_by_updated_at_asc(self):
|
||||
"""Returns pipelines sorted by updated_at ascending."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
await service.get_pipelines(sort_by='updated_at', sort_order='ASC')
|
||||
|
||||
# Verify - execute was called with sort parameters
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceGetPipeline:
|
||||
"""Tests for get_pipeline method."""
|
||||
|
||||
async def test_get_pipeline_by_uuid_found(self):
|
||||
"""Returns pipeline when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
pipeline = _create_mock_pipeline(pipeline_uuid='test-uuid', name='Found Pipeline')
|
||||
mock_result = _create_mock_result(first_item=pipeline)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'name': 'Found Pipeline',
|
||||
'stages': default_stage_order,
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline('test-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'test-uuid'
|
||||
assert result['name'] == 'Found Pipeline'
|
||||
|
||||
async def test_get_pipeline_by_uuid_not_found(self):
|
||||
"""Returns None when pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_pipeline('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPipelineServiceCreatePipeline:
|
||||
"""Tests for create_pipeline method."""
|
||||
|
||||
async def test_create_pipeline_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_pipelines limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||
await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
async def test_create_pipeline_no_limit(self):
|
||||
"""Creates pipeline without limit when max_pipelines=-1."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
# Override get_pipelines to return empty list (no limit check issue)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
|
||||
|
||||
# Mock persistence for insert
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
|
||||
)
|
||||
|
||||
# Mock the file read for default config - patch at the utils module level
|
||||
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
# Verify
|
||||
assert bot_uuid is not None
|
||||
assert len(bot_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_pipeline_as_default(self):
|
||||
"""Creates pipeline with is_default=True."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
|
||||
)
|
||||
|
||||
# Mock the file read
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
|
||||
|
||||
# Verify - execute was called
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_create_pipeline_sets_default_extensions_preferences(self):
|
||||
"""Sets default extensions_preferences when not provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
})
|
||||
|
||||
insert_params = []
|
||||
|
||||
async def mock_execute(query):
|
||||
params = query.compile().params
|
||||
if 'extensions_preferences' in params:
|
||||
insert_params.append(params)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
}
|
||||
)
|
||||
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
assert len(insert_params) == 1
|
||||
assert insert_params[0]['extensions_preferences'] == {
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': [],
|
||||
}
|
||||
|
||||
|
||||
class _MockResultWithBots:
|
||||
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
|
||||
def __init__(self, bots_list):
|
||||
self._bots_list = bots_list
|
||||
|
||||
def all(self):
|
||||
return self._bots_list
|
||||
|
||||
def first(self):
|
||||
return self._bots_list[0] if self._bots_list else None
|
||||
|
||||
|
||||
class TestPipelineServiceUpdatePipeline:
|
||||
"""Tests for update_pipeline method."""
|
||||
|
||||
async def test_update_pipeline_removes_protected_fields(self):
|
||||
"""Does not persist protected fields from update data."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
ap.bot_service = None # No bot_service when not updating name
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
|
||||
|
||||
# Execute with protected fields - no name change, so no bot sync
|
||||
pipeline_data = {
|
||||
'uuid': 'should-be-removed',
|
||||
'for_version': 'should-be-removed',
|
||||
'stages': ['should-be-removed'],
|
||||
'is_default': True,
|
||||
'description': 'New description', # Not name change, so no bot_service needed
|
||||
}
|
||||
await service.update_pipeline('test-uuid', pipeline_data)
|
||||
|
||||
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
|
||||
assert update_params['description'] == 'New description'
|
||||
assert 'should-be-removed' not in update_params.values()
|
||||
assert ['should-be-removed'] not in update_params.values()
|
||||
assert not any(value is True for value in update_params.values())
|
||||
|
||||
async def test_update_pipeline_syncs_bot_names(self):
|
||||
"""Updates bot use_pipeline_name when pipeline name changes."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
ap.sess_mgr.session_list = []
|
||||
ap.bot_service = SimpleNamespace()
|
||||
ap.bot_service.update_bot = AsyncMock()
|
||||
|
||||
# Create proper mock Bot entities with uuid attribute
|
||||
mock_bot1 = Mock()
|
||||
mock_bot1.uuid = 'bot-uuid-1'
|
||||
mock_bot2 = Mock()
|
||||
mock_bot2.uuid = 'bot-uuid-2'
|
||||
|
||||
# Create bot list
|
||||
bot_list = [mock_bot1, mock_bot2]
|
||||
|
||||
# Create mock result using helper class
|
||||
bot_result = _MockResultWithBots(bot_list)
|
||||
|
||||
# The order of calls in update_pipeline:
|
||||
# 1. UPDATE (line 125) - returns Mock (no result needed)
|
||||
# 2. SELECT bots (line 136) - returns bot_result with .all()
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First call is the UPDATE - just return a Mock
|
||||
return Mock()
|
||||
elif call_count == 2:
|
||||
# Second call is the SELECT bots - return proper result
|
||||
return bot_result
|
||||
return Mock() # Any additional calls
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'New Name'})
|
||||
|
||||
# Execute with name change
|
||||
await service.update_pipeline('test-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify - bot_service.update_bot was called for each bot
|
||||
assert ap.bot_service.update_bot.call_count == 2
|
||||
|
||||
async def test_update_pipeline_clears_conversations(self):
|
||||
"""Clears session conversations using this pipeline."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.sess_mgr = SimpleNamespace()
|
||||
|
||||
# Mock session with conversation using this pipeline
|
||||
session = SimpleNamespace()
|
||||
session.using_conversation = SimpleNamespace()
|
||||
session.using_conversation.pipeline_uuid = 'test-uuid'
|
||||
ap.sess_mgr.session_list = [session]
|
||||
ap.bot_service = SimpleNamespace()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid'})
|
||||
|
||||
# Execute
|
||||
await service.update_pipeline('test-uuid', {'description': 'Updated'})
|
||||
|
||||
# Verify - conversation was cleared
|
||||
assert session.using_conversation is None
|
||||
|
||||
|
||||
class TestPipelineServiceDeletePipeline:
|
||||
"""Tests for delete_pipeline method."""
|
||||
|
||||
async def test_delete_pipeline_calls_remove_and_delete(self):
|
||||
"""Calls both pipeline_mgr.remove_pipeline and persistence delete."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_pipeline('test-uuid')
|
||||
|
||||
# Verify
|
||||
ap.pipeline_mgr.remove_pipeline.assert_called_once_with('test-uuid')
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_pipeline_nonexistent_uuid(self):
|
||||
"""Delete operation completes even for nonexistent UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_pipeline('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
ap.pipeline_mgr.remove_pipeline.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceCopyPipeline:
|
||||
"""Tests for copy_pipeline method."""
|
||||
|
||||
async def test_copy_pipeline_max_limit_reached_raises(self):
|
||||
"""Raises ValueError when max_pipelines limit reached."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
# Mock get_pipelines to return 2 pipelines
|
||||
service.get_pipelines = AsyncMock(return_value=[
|
||||
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
||||
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
||||
])
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
async def test_copy_pipeline_not_found_raises(self):
|
||||
"""Raises ValueError when original pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
|
||||
ap.persistence_mgr.execute_async = AsyncMock(
|
||||
return_value=_create_mock_result(first_item=None) # Original not found
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Pipeline original-uuid not found'):
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
async def test_copy_pipeline_creates_copy(self):
|
||||
"""Creates a copy with (Copy) suffix."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
original = _create_mock_pipeline(
|
||||
pipeline_uuid='original-uuid',
|
||||
name='Original Pipeline',
|
||||
description='Original description',
|
||||
stages=['Stage1', 'Stage2'],
|
||||
config={'key': 'value'},
|
||||
extensions_preferences={'enable_all_plugins': False, 'plugins': ['plugin1']},
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
|
||||
|
||||
# Mock persistence - get original, then insert, then get new
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'new-copy-uuid',
|
||||
'name': 'Original Pipeline (Copy)',
|
||||
}
|
||||
)
|
||||
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'new-copy-uuid',
|
||||
'name': 'Original Pipeline (Copy)',
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
new_uuid = await service.copy_pipeline('original-uuid')
|
||||
|
||||
# Verify
|
||||
assert new_uuid is not None
|
||||
assert len(new_uuid) == 36 # UUID format
|
||||
|
||||
async def test_copy_pipeline_is_not_default(self):
|
||||
"""Copy is never set as default."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
|
||||
|
||||
# Original is default
|
||||
original = _create_mock_pipeline(
|
||||
pipeline_uuid='original-uuid',
|
||||
name='Default Pipeline',
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'copy-uuid', 'is_default': False}
|
||||
)
|
||||
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
||||
|
||||
# Execute
|
||||
await service.copy_pipeline('original-uuid')
|
||||
|
||||
# Verify - pipeline_mgr.load_pipeline called (copy created)
|
||||
ap.pipeline_mgr.load_pipeline.assert_called_once()
|
||||
|
||||
|
||||
class TestPipelineServiceUpdatePipelineExtensions:
|
||||
"""Tests for update_pipeline_extensions method."""
|
||||
|
||||
async def test_update_extensions_pipeline_not_found_raises(self):
|
||||
"""Raises ValueError when pipeline not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Pipeline nonexistent-uuid not found'):
|
||||
await service.update_pipeline_extensions('nonexistent-uuid', [])
|
||||
|
||||
async def test_update_extensions_sets_plugins(self):
|
||||
"""Updates plugins in extensions_preferences."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline(
|
||||
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
bound_plugins = [{'plugin_uuid': 'plugin-1'}]
|
||||
await service.update_pipeline_extensions(
|
||||
'test-uuid',
|
||||
bound_plugins=bound_plugins,
|
||||
enable_all_plugins=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_update_extensions_sets_mcp_servers(self):
|
||||
"""Updates MCP servers in extensions_preferences."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline()
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {
|
||||
'enable_all_mcp_servers': False,
|
||||
'mcp_servers': ['mcp-server-1'],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-uuid',
|
||||
'extensions_preferences': {'mcp_servers': ['mcp-server-1']},
|
||||
}
|
||||
)
|
||||
|
||||
# Execute
|
||||
await service.update_pipeline_extensions(
|
||||
'test-uuid',
|
||||
bound_plugins=[],
|
||||
bound_mcp_servers=['mcp-server-1'],
|
||||
enable_all_mcp_servers=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
async def test_update_extensions_none_mcp_servers_keeps_existing(self):
|
||||
"""Does not modify mcp_servers when bound_mcp_servers is None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline(
|
||||
extensions_preferences={
|
||||
'enable_all_plugins': True,
|
||||
'enable_all_mcp_servers': True,
|
||||
'plugins': [],
|
||||
'mcp_servers': ['existing-server'],
|
||||
}
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _create_mock_result(first_item=original_pipeline)
|
||||
return Mock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
|
||||
)
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
|
||||
)
|
||||
|
||||
# Execute - bound_mcp_servers is None (not provided)
|
||||
await service.update_pipeline_extensions('test-uuid', bound_plugins=[])
|
||||
|
||||
# Verify - persistence was called
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
|
||||
|
||||
class TestDefaultStageOrder:
|
||||
"""Tests for default_stage_order constant."""
|
||||
|
||||
def test_default_stage_order_not_empty(self):
|
||||
"""Default stage order is not empty."""
|
||||
assert len(default_stage_order) > 0
|
||||
|
||||
def test_default_stage_order_contains_key_stages(self):
|
||||
"""Default stage order contains key processing stages."""
|
||||
assert 'MessageProcessor' in default_stage_order
|
||||
assert 'SendResponseBackStage' in default_stage_order
|
||||
866
tests/unit_tests/api/service/test_provider_service.py
Normal file
866
tests/unit_tests/api/service/test_provider_service.py
Normal file
@@ -0,0 +1,866 @@
|
||||
"""
|
||||
Unit tests for ModelProviderService.
|
||||
|
||||
Tests model provider management operations including:
|
||||
- Provider CRUD operations
|
||||
- Provider model count checking
|
||||
- Find or create provider logic
|
||||
- Space model provider API key updates
|
||||
- Provider model scanning
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/provider.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.provider import ModelProviderService
|
||||
from langbot.pkg.entity.persistence.model import ModelProvider, LLMModel, EmbeddingModel, RerankModel
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_provider(
|
||||
provider_uuid: str = 'test-provider-uuid',
|
||||
name: str = 'Test Provider',
|
||||
requester: str = 'openai',
|
||||
base_url: str = 'https://api.openai.com',
|
||||
api_keys: list = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock ModelProvider entity."""
|
||||
provider = Mock(spec=ModelProvider)
|
||||
provider.uuid = provider_uuid
|
||||
provider.name = name
|
||||
provider.requester = requester
|
||||
provider.base_url = base_url
|
||||
provider.api_keys = api_keys or ['test-key']
|
||||
return provider
|
||||
|
||||
|
||||
def _create_mock_llm_model(
|
||||
model_uuid: str = 'test-llm-uuid',
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'test-provider-uuid',
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
model = Mock(spec=LLMModel)
|
||||
model.uuid = model_uuid
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
return model
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
result.scalar = Mock(return_value=len(items) if items else 0)
|
||||
return result
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProviders:
|
||||
"""Tests for get_providers method."""
|
||||
|
||||
async def test_get_providers_empty_list(self):
|
||||
"""Returns empty list when no providers exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'requester': entity.requester,
|
||||
'base_url': entity.base_url,
|
||||
'api_keys': entity.api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_providers_returns_serialized_list(self):
|
||||
"""Returns serialized list of providers."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider1 = _create_mock_provider(provider_uuid='provider-1', name='Provider 1')
|
||||
provider2 = _create_mock_provider(provider_uuid='provider-2', name='Provider 2')
|
||||
|
||||
mock_result = _create_mock_result([provider1, provider2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'requester': entity.requester,
|
||||
'base_url': entity.base_url,
|
||||
'api_keys': entity.api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Provider 1'
|
||||
assert result[1]['name'] == 'Provider 2'
|
||||
|
||||
async def test_get_providers_parse_api_keys_json_string(self):
|
||||
"""Parses api_keys from JSON string if needed."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='["key1", "key2"]')
|
||||
|
||||
mock_result = _create_mock_result([provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'api_keys': entity.api_keys, # Returns string
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify - api_keys should be parsed from string
|
||||
assert result[0]['api_keys'] == ['key1', 'key2']
|
||||
|
||||
async def test_get_providers_invalid_json_api_keys_returns_empty(self):
|
||||
"""Returns empty list for invalid JSON api_keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='invalid-json')
|
||||
|
||||
mock_result = _create_mock_result([provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'api_keys': entity.api_keys, # Returns invalid string
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_providers()
|
||||
|
||||
# Verify - invalid JSON returns empty list
|
||||
assert result[0]['api_keys'] == []
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProvider:
|
||||
"""Tests for get_provider method."""
|
||||
|
||||
async def test_get_provider_by_uuid_found(self):
|
||||
"""Returns provider when found by UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='found-uuid', name='Found Provider')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Found Provider',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider('found-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
|
||||
async def test_get_provider_by_uuid_not_found(self):
|
||||
"""Returns None when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestModelProviderServiceCreateProvider:
|
||||
"""Tests for create_provider method."""
|
||||
|
||||
async def test_create_provider_generates_uuid(self):
|
||||
"""Creates provider with generated UUID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
# Mock load_provider to return runtime provider
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'generated-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
provider_uuid = await service.create_provider({
|
||||
'name': 'New Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
|
||||
# Verify - UUID is generated
|
||||
assert provider_uuid is not None
|
||||
assert len(provider_uuid) == 36 # UUID format
|
||||
|
||||
async def test_create_provider_loads_to_runtime(self):
|
||||
"""Loads provider to runtime model_mgr."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'runtime-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result_uuid = await service.create_provider({
|
||||
'name': 'Runtime Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
|
||||
# Verify - provider added to runtime dict and UUID generated
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestModelProviderServiceUpdateProvider:
|
||||
"""Tests for update_provider method."""
|
||||
|
||||
async def test_update_provider_removes_uuid_from_data(self):
|
||||
"""Removes uuid from update data before persisting."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_provider('existing-uuid', {
|
||||
'uuid': 'should-be-removed', # Will be removed
|
||||
'name': 'Updated Name',
|
||||
})
|
||||
|
||||
# Verify - reload called
|
||||
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
|
||||
|
||||
async def test_update_provider_reloads_runtime(self):
|
||||
"""Reloads provider in runtime after update."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_provider('update-uuid', {'name': 'New Name'})
|
||||
|
||||
# Verify
|
||||
ap.model_mgr.reload_provider.assert_called_once()
|
||||
|
||||
|
||||
class TestModelProviderServiceDeleteProvider:
|
||||
"""Tests for delete_provider method."""
|
||||
|
||||
async def test_delete_provider_with_llm_models_raises_error(self):
|
||||
"""Raises ValueError when LLM models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock LLM model exists - only return LLM result since that's first check
|
||||
llm_result = _create_mock_result([], first_item=_create_mock_llm_model())
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=llm_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: LLM models'):
|
||||
await service.delete_provider('provider-with-llm')
|
||||
|
||||
async def test_delete_provider_with_embedding_models_raises_error(self):
|
||||
"""Raises ValueError when Embedding models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create results for each check type
|
||||
llm_result = Mock()
|
||||
llm_result.first = Mock(return_value=None) # No LLM models
|
||||
embedding_result = Mock()
|
||||
embedding_result.first = Mock(return_value=Mock(spec=EmbeddingModel)) # Has embedding model
|
||||
rerank_result = Mock()
|
||||
rerank_result.first = Mock(return_value=None)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify - should raise embedding error (LLM check passes, embedding check fails)
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: Embedding models'):
|
||||
await service.delete_provider('provider-with-embedding')
|
||||
|
||||
async def test_delete_provider_with_rerank_models_raises_error(self):
|
||||
"""Raises ValueError when Rerank models reference provider."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Create results for each check type
|
||||
llm_result = Mock()
|
||||
llm_result.first = Mock(return_value=None) # No LLM models
|
||||
embedding_result = Mock()
|
||||
embedding_result.first = Mock(return_value=None) # No embedding models
|
||||
rerank_result = Mock()
|
||||
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify - should raise rerank error (LLM and embedding checks pass, rerank check fails)
|
||||
with pytest.raises(ValueError, match='Cannot delete provider: Rerank models'):
|
||||
await service.delete_provider('provider-with-rerank')
|
||||
|
||||
async def test_delete_provider_no_models_success(self):
|
||||
"""Deletes provider when no models reference it."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.remove_provider = AsyncMock()
|
||||
|
||||
# Mock no models reference provider
|
||||
empty_result = Mock()
|
||||
empty_result.first = Mock(return_value=None)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=empty_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_provider('provider-no-models')
|
||||
|
||||
# Verify - delete and remove called
|
||||
ap.model_mgr.remove_provider.assert_called_once_with('provider-no-models')
|
||||
|
||||
|
||||
class TestModelProviderServiceGetProviderModelCounts:
|
||||
"""Tests for get_provider_model_counts method."""
|
||||
|
||||
async def test_get_model_counts_returns_correct_counts(self):
|
||||
"""Returns correct counts for each model type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock scalar results for counts
|
||||
llm_result = Mock()
|
||||
llm_result.scalar = Mock(return_value=3)
|
||||
embedding_result = Mock()
|
||||
embedding_result.scalar = Mock(return_value=2)
|
||||
rerank_result = Mock()
|
||||
rerank_result.scalar = Mock(return_value=1)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return llm_result
|
||||
elif call_count == 2:
|
||||
return embedding_result
|
||||
return rerank_result
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider_model_counts('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert result['llm_count'] == 3
|
||||
assert result['embedding_count'] == 2
|
||||
assert result['rerank_count'] == 1
|
||||
|
||||
async def test_get_model_counts_zero_counts(self):
|
||||
"""Returns zero counts when no models."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
zero_result = Mock()
|
||||
zero_result.scalar = Mock(return_value=0)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=zero_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_provider_model_counts('empty-provider')
|
||||
|
||||
# Verify
|
||||
assert result['llm_count'] == 0
|
||||
assert result['embedding_count'] == 0
|
||||
assert result['rerank_count'] == 0
|
||||
|
||||
|
||||
class TestModelProviderServiceFindOrCreateProvider:
|
||||
"""Tests for find_or_create_provider method."""
|
||||
|
||||
async def test_find_existing_provider_matching_config(self):
|
||||
"""Returns existing provider UUID when config matches."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
existing_provider = _create_mock_provider(
|
||||
provider_uuid='existing-uuid',
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'],
|
||||
)
|
||||
|
||||
mock_result = _create_mock_result([existing_provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.find_or_create_provider(
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'], # Same keys (sorted)
|
||||
)
|
||||
|
||||
# Verify - returns existing UUID
|
||||
assert result == 'existing-uuid'
|
||||
|
||||
async def test_find_existing_provider_keys_order_mismatch(self):
|
||||
"""Returns existing provider when keys match but order differs."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
existing_provider = _create_mock_provider(
|
||||
provider_uuid='existing-uuid',
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key1', 'key2'],
|
||||
)
|
||||
|
||||
mock_result = _create_mock_result([existing_provider])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute with reversed key order
|
||||
result = await service.find_or_create_provider(
|
||||
requester='openai',
|
||||
base_url='https://api.openai.com',
|
||||
api_keys=['key2', 'key1'], # Different order, should still match
|
||||
)
|
||||
|
||||
# Verify - returns existing UUID (keys are sorted in comparison)
|
||||
assert result == 'existing-uuid'
|
||||
|
||||
async def test_create_new_provider_no_match(self):
|
||||
"""Creates new provider when no existing match."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = None # Will be set by uuid.uuid4()
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock no existing providers
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.find_or_create_provider(
|
||||
requester='new-requester',
|
||||
base_url='https://new.api.com',
|
||||
api_keys=['new-key'],
|
||||
)
|
||||
|
||||
# Verify - creates new provider with valid UUID format
|
||||
assert result is not None
|
||||
assert len(result) == 36 # UUID format
|
||||
# Verify provider was loaded to runtime
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
|
||||
async def test_create_provider_name_from_url_parse(self):
|
||||
"""Creates provider with name parsed from URL."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {}
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.provider_entity = Mock()
|
||||
runtime_provider.provider_entity.uuid = 'parsed-url-uuid'
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result_uuid = await service.find_or_create_provider(
|
||||
requester='custom',
|
||||
base_url='https://api.example.com/v1',
|
||||
api_keys=['key'],
|
||||
)
|
||||
|
||||
# Verify - name should be parsed from URL (api.example.com)
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
assert result_uuid is not None
|
||||
|
||||
|
||||
class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
|
||||
"""Tests for update_space_model_provider_api_keys method."""
|
||||
|
||||
async def test_update_space_provider_api_keys(self):
|
||||
"""Updates Space provider API keys."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.reload_provider = AsyncMock()
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_space_model_provider_api_keys('space-api-key')
|
||||
|
||||
# Verify - update and reload called for Space provider UUID
|
||||
ap.model_mgr.reload_provider.assert_called_once_with(
|
||||
'00000000-0000-0000-0000-000000000000'
|
||||
)
|
||||
|
||||
|
||||
class TestModelProviderServiceScanProviderModels:
|
||||
"""Tests for scan_provider_models method."""
|
||||
|
||||
async def test_scan_provider_not_found_raises_error(self):
|
||||
"""Raises ValueError when provider not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result([], first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='provider not found'):
|
||||
await service.scan_provider_models('nonexistent-uuid')
|
||||
|
||||
async def test_scan_provider_returns_models_list(self):
|
||||
"""Returns scanned models list."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='scan-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'scan-uuid',
|
||||
'name': 'Scan Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
# Mock runtime provider with scan capability
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
# Mock scan_models to return models
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
|
||||
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock existing model services
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.scan_provider_models('scan-uuid')
|
||||
|
||||
# Verify
|
||||
assert 'models' in result
|
||||
assert len(result['models']) == 2
|
||||
|
||||
async def test_scan_provider_filter_by_model_type(self):
|
||||
"""Returns filtered models by type."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='filter-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'filter-uuid',
|
||||
'name': 'Filter Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
|
||||
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute - filter for LLM only
|
||||
result = await service.scan_provider_models('filter-uuid', model_type='llm')
|
||||
|
||||
# Verify - only LLM models returned
|
||||
assert len(result['models']) == 1
|
||||
assert result['models'][0]['type'] == 'llm'
|
||||
|
||||
async def test_scan_provider_not_implemented_raises_error(self):
|
||||
"""Raises ValueError when scan not implemented."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='no-scan-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'no-scan-uuid',
|
||||
'name': 'No Scan Provider',
|
||||
'requester': 'custom',
|
||||
'base_url': 'https://custom.api.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
runtime_provider.requester.scan_models = AsyncMock(
|
||||
side_effect=NotImplementedError('scan not supported')
|
||||
)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='current provider does not support model scanning'):
|
||||
await service.scan_provider_models('no-scan-uuid')
|
||||
|
||||
async def test_scan_provider_marks_already_added_models(self):
|
||||
"""Marks models that are already added."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.llm_model_service = SimpleNamespace()
|
||||
ap.embedding_models_service = SimpleNamespace()
|
||||
|
||||
provider = _create_mock_provider(provider_uuid='already-added-uuid')
|
||||
|
||||
mock_result = _create_mock_result([], first_item=provider)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'already-added-uuid',
|
||||
'name': 'Already Added Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_provider = Mock()
|
||||
runtime_provider.requester = Mock()
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
|
||||
async def mock_scan_models(token):
|
||||
return {
|
||||
'models': [
|
||||
{'id': 'existing-model', 'name': 'Existing Model', 'type': 'llm'},
|
||||
{'id': 'new-model', 'name': 'New Model', 'type': 'llm'},
|
||||
],
|
||||
'debug': None,
|
||||
}
|
||||
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock existing LLM model
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
|
||||
return_value=[{'name': 'Existing Model'}]
|
||||
)
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.scan_provider_models('already-added-uuid')
|
||||
|
||||
# Verify - existing model marked as already_added
|
||||
existing_model = next(m for m in result['models'] if m['name'] == 'Existing Model')
|
||||
assert existing_model['already_added'] is True
|
||||
|
||||
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
|
||||
assert new_model['already_added'] is False
|
||||
778
tests/unit_tests/api/service/test_space_service.py
Normal file
778
tests/unit_tests/api/service/test_space_service.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""
|
||||
Unit tests for SpaceService.
|
||||
|
||||
Tests LangBot Space API interactions including:
|
||||
- OAuth URL generation
|
||||
- Token exchange and refresh
|
||||
- User info retrieval
|
||||
- Credits caching
|
||||
- Model listing
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/space.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from types import SimpleNamespace
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from langbot.pkg.api.http.service.space import SpaceService
|
||||
from langbot.pkg.entity.persistence.user import User
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_user(
|
||||
email: str = 'test@example.com',
|
||||
account_type: str = 'space',
|
||||
space_account_uuid: str = 'space-uuid-123',
|
||||
space_access_token: str = 'access_token_123',
|
||||
space_refresh_token: str = 'refresh_token_123',
|
||||
space_access_token_expires_at: datetime.datetime = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock User entity."""
|
||||
user = Mock(spec=User)
|
||||
user.user = email
|
||||
user.account_type = account_type
|
||||
user.space_account_uuid = space_account_uuid
|
||||
user.space_access_token = space_access_token
|
||||
user.space_refresh_token = space_refresh_token
|
||||
user.space_access_token_expires_at = space_access_token_expires_at
|
||||
return user
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestSpaceServiceGetOAuthAuthorizeUrl:
|
||||
"""Tests for get_oauth_authorize_url method."""
|
||||
|
||||
def test_get_oauth_authorize_url_basic(self):
|
||||
"""Returns OAuth URL with redirect_uri."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'space': {
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
}
|
||||
}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback')
|
||||
|
||||
# Verify
|
||||
assert 'redirect_uri=http://localhost/callback' in result
|
||||
assert 'https://space.langbot.app/auth/authorize' in result
|
||||
|
||||
def test_get_oauth_authorize_url_with_state(self):
|
||||
"""Returns OAuth URL with redirect_uri and state."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'space': {
|
||||
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
|
||||
}
|
||||
}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback', state='random_state')
|
||||
|
||||
# Verify
|
||||
assert 'redirect_uri=http://localhost/callback' in result
|
||||
assert 'state=random_state' in result
|
||||
|
||||
def test_get_oauth_authorize_url_default_config(self):
|
||||
"""Uses default OAuth URL when config not set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = service.get_oauth_authorize_url('http://localhost/callback')
|
||||
|
||||
# Verify - uses default URL
|
||||
assert 'https://space.langbot.app/auth/authorize' in result
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserByEmail:
|
||||
"""Tests for _get_user_by_email internal method."""
|
||||
|
||||
async def test_get_user_by_email_found(self):
|
||||
"""Returns user when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='found@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._get_user_by_email('found@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'found@example.com'
|
||||
|
||||
async def test_get_user_by_email_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._get_user_by_email('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpaceServiceEnsureValidToken:
|
||||
"""Tests for _ensure_valid_token internal method."""
|
||||
|
||||
async def test_ensure_valid_token_user_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_not_space_account(self):
|
||||
"""Returns None when user is not a space account."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='local@example.com', account_type='local')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('local@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_no_access_token(self):
|
||||
"""Returns None when user has no access token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(space_access_token=None)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_ensure_valid_token_valid_token(self):
|
||||
"""Returns valid access token when not expired."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Token expires in 1 hour (valid)
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result == 'valid_token'
|
||||
|
||||
async def test_ensure_valid_token_expired_no_refresh(self):
|
||||
"""Returns None when token expired and no refresh token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Token expired 1 hour ago
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='expired_token',
|
||||
space_refresh_token=None,
|
||||
space_access_token_expires_at=datetime.datetime.now() - datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service._ensure_valid_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpaceServiceGetCredits:
|
||||
"""Tests for get_credits method."""
|
||||
|
||||
async def test_get_credits_no_user(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_credits_returns_cached_value(self):
|
||||
"""Returns cached credits without API call."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate cache
|
||||
service._credits_cache = {'cached@example.com': (100, time.time())}
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('cached@example.com')
|
||||
|
||||
# Verify - returns cached value without API call
|
||||
assert result == 100
|
||||
|
||||
async def test_get_credits_cache_expired_refreshes(self):
|
||||
"""Refreshes expired cache."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate expired cache (70 seconds ago, past 60s TTL)
|
||||
service._credits_cache = {'test@example.com': (50, time.time() - 70)}
|
||||
|
||||
# Mock get_user_info to return new credits
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 200})
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - cache was refreshed
|
||||
assert result == 200
|
||||
assert service._credits_cache['test@example.com'][0] == 200
|
||||
|
||||
async def test_get_credits_force_refresh(self):
|
||||
"""Force refresh ignores cache."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate cache
|
||||
service._credits_cache = {'test@example.com': (100, time.time())}
|
||||
|
||||
# Mock get_user_info to return new credits
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 300})
|
||||
|
||||
# Execute with force_refresh=True
|
||||
result = await service.get_credits('test@example.com', force_refresh=True)
|
||||
|
||||
# Verify - fresh value returned
|
||||
assert result == 300
|
||||
|
||||
async def test_get_credits_returns_cached_on_exception(self):
|
||||
"""Returns cached fallback value when API fails."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Pre-populate expired cache - will try to refresh and fail
|
||||
service._credits_cache = {'test@example.com': (150, time.time() - 70)}
|
||||
|
||||
# Mock get_user_info to raise exception
|
||||
service.get_user_info = AsyncMock(side_effect=Exception('API Error'))
|
||||
|
||||
# Execute - should return cached fallback value (even though expired)
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - returns cached fallback value (150) because API failed
|
||||
assert result == 150
|
||||
|
||||
|
||||
class TestSpaceServiceRefreshToken:
|
||||
"""Tests for refresh_token method."""
|
||||
|
||||
async def test_refresh_token_success(self):
|
||||
"""Refreshes token successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
# Use async context manager mock
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.refresh_token('old_refresh_token')
|
||||
|
||||
# Verify
|
||||
assert result['access_token'] == 'new_access_token'
|
||||
|
||||
async def test_refresh_token_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 1,
|
||||
'msg': 'Invalid refresh token',
|
||||
})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to refresh token'):
|
||||
await service.refresh_token('invalid_refresh_token')
|
||||
|
||||
async def test_refresh_token_http_error(self):
|
||||
"""Raises ValueError on HTTP error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error status
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
mock_response.text = AsyncMock(return_value='Internal Server Error')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to refresh token'):
|
||||
await service.refresh_token('refresh_token')
|
||||
|
||||
|
||||
class TestSpaceServiceExchangeOAuthCode:
|
||||
"""Tests for exchange_oauth_code method."""
|
||||
|
||||
async def test_exchange_oauth_code_success(self):
|
||||
"""Exchanges OAuth code successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.exchange_oauth_code('auth_code')
|
||||
|
||||
# Verify
|
||||
assert result['access_token'] == 'new_access_token'
|
||||
|
||||
async def test_exchange_oauth_code_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Invalid code'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid code"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.post = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to exchange OAuth code'):
|
||||
await service.exchange_oauth_code('invalid_code')
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserInfoRaw:
|
||||
"""Tests for get_user_info_raw method."""
|
||||
|
||||
async def test_get_user_info_raw_success(self):
|
||||
"""Gets user info successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'email': 'test@example.com',
|
||||
'credits': 100,
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info_raw('access_token')
|
||||
|
||||
# Verify
|
||||
assert result['email'] == 'test@example.com'
|
||||
assert result['credits'] == 100
|
||||
|
||||
async def test_get_user_info_raw_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to get user info'):
|
||||
await service.get_user_info_raw('invalid_token')
|
||||
|
||||
|
||||
class TestSpaceServiceGetUserInfo:
|
||||
"""Tests for get_user_info method (with token validation)."""
|
||||
|
||||
async def test_get_user_info_no_token(self):
|
||||
"""Returns None when no valid token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_user_info_with_valid_token(self):
|
||||
"""Returns user info with valid token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock get_user_info_raw
|
||||
service.get_user_info_raw = AsyncMock(return_value={'email': 'test@example.com', 'credits': 100})
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_info('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert result['email'] == 'test@example.com'
|
||||
|
||||
|
||||
class TestSpaceServiceGetModels:
|
||||
"""Tests for get_models method."""
|
||||
|
||||
async def test_get_models_success(self):
|
||||
"""Gets models successfully."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with proper model data matching SpaceModel schema
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'models': [
|
||||
{
|
||||
'uuid': 'uuid-1',
|
||||
'model_id': 'model-1',
|
||||
'provider': 'provider-1',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
{
|
||||
'uuid': 'uuid-2',
|
||||
'model_id': 'model-2',
|
||||
'provider': 'provider-2',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await service.get_models()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
|
||||
async def test_get_models_api_error(self):
|
||||
"""Raises ValueError on API error."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.get = MagicMock(return_value=mock_response)
|
||||
mock_session.return_value = mock_session_obj
|
||||
|
||||
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Failed to get models'):
|
||||
await service.get_models()
|
||||
|
||||
|
||||
class TestSpaceServiceCreditsCache:
|
||||
"""Tests for credits cache behavior."""
|
||||
|
||||
def test_credits_cache_initialized(self):
|
||||
"""Verify _credits_cache is initialized as empty dict."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Verify
|
||||
assert hasattr(service, '_credits_cache')
|
||||
assert service._credits_cache == {}
|
||||
|
||||
async def test_credits_cache_updates_on_success(self):
|
||||
"""Cache updates when get_credits succeeds."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_user = _create_mock_user(
|
||||
space_access_token='valid_token',
|
||||
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = SpaceService(ap)
|
||||
|
||||
# Mock get_user_info
|
||||
service.get_user_info = AsyncMock(return_value={'credits': 500})
|
||||
|
||||
# Execute
|
||||
result = await service.get_credits('test@example.com')
|
||||
|
||||
# Verify - cache updated
|
||||
assert result == 500
|
||||
assert 'test@example.com' in service._credits_cache
|
||||
assert service._credits_cache['test@example.com'][0] == 500
|
||||
608
tests/unit_tests/api/service/test_user_service.py
Normal file
608
tests/unit_tests/api/service/test_user_service.py
Normal file
@@ -0,0 +1,608 @@
|
||||
"""
|
||||
Unit tests for UserService.
|
||||
|
||||
Tests user management operations including:
|
||||
- User initialization check
|
||||
- Local user creation and authentication
|
||||
- JWT token generation and verification
|
||||
- Password management (reset, change, set)
|
||||
- Space account management
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/user.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.user import UserService
|
||||
from langbot.pkg.entity.persistence.user import User
|
||||
from langbot.pkg.entity.errors.account import AccountEmailMismatchError
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_user(
|
||||
email: str = 'test@example.com',
|
||||
password: str = 'hashed_password',
|
||||
account_type: str = 'local',
|
||||
space_account_uuid: str = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock User entity."""
|
||||
user = Mock(spec=User)
|
||||
user.user = email
|
||||
user.password = password
|
||||
user.account_type = account_type
|
||||
user.space_account_uuid = space_account_uuid
|
||||
return user
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestUserServiceIsInitialized:
|
||||
"""Tests for is_initialized method."""
|
||||
|
||||
async def test_is_initialized_returns_true_when_users_exist(self):
|
||||
"""Returns True when at least one user exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user()
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
async def test_is_initialized_returns_false_when_no_users(self):
|
||||
"""Returns False when no users exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
async def test_is_initialized_returns_false_on_none_result(self):
|
||||
"""Returns False when result is None."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = Mock()
|
||||
mock_result.all = Mock(return_value=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.is_initialized()
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestUserServiceGetUserByEmail:
|
||||
"""Tests for get_user_by_email method."""
|
||||
|
||||
async def test_get_user_by_email_found(self):
|
||||
"""Returns user when found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='found@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('found@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'found@example.com'
|
||||
|
||||
async def test_get_user_by_email_not_found(self):
|
||||
"""Returns None when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('notfound@example.com')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_user_by_email_empty_string(self):
|
||||
"""Handles empty email string."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_email('')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceGetUserBySpaceAccountUuid:
|
||||
"""Tests for get_user_by_space_account_uuid method."""
|
||||
|
||||
async def test_get_user_by_space_uuid_found(self):
|
||||
"""Returns user when Space UUID found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='space-uuid-123',
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_space_account_uuid('space-uuid-123')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.space_account_uuid == 'space-uuid-123'
|
||||
|
||||
async def test_get_user_by_space_uuid_not_found(self):
|
||||
"""Returns None when Space UUID not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_user_by_space_account_uuid('nonexistent-uuid')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceAuthenticate:
|
||||
"""Tests for authenticate method."""
|
||||
|
||||
async def test_authenticate_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='用户不存在'):
|
||||
await service.authenticate('nonexistent@example.com', 'password')
|
||||
|
||||
async def test_authenticate_space_user_without_password_raises_error(self):
|
||||
"""Raises ValueError for Space user without local password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
# Space user has empty password
|
||||
mock_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
password='', # Empty password for Space user
|
||||
account_type='space',
|
||||
)
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='请使用 Space 账户登录'):
|
||||
await service.authenticate('space@example.com', 'password')
|
||||
|
||||
|
||||
class TestUserServiceGenerateJwtToken:
|
||||
"""Tests for generate_jwt_token method."""
|
||||
|
||||
async def test_generate_jwt_token_returns_valid_token(self):
|
||||
"""Generates valid JWT token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
token = await service.generate_jwt_token('test@example.com')
|
||||
|
||||
# Verify - JWT format (base64 encoded parts)
|
||||
assert token is not None
|
||||
assert len(token) > 0
|
||||
parts = token.split('.')
|
||||
assert len(parts) == 3 # JWT has 3 parts
|
||||
|
||||
async def test_generate_jwt_token_custom_expire(self):
|
||||
"""Generates token with custom expiry."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 7200}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
token = await service.generate_jwt_token('test@example.com')
|
||||
|
||||
# Verify
|
||||
assert token is not None
|
||||
|
||||
|
||||
class TestUserServiceVerifyJwtToken:
|
||||
"""Tests for verify_jwt_token method."""
|
||||
|
||||
async def test_verify_jwt_token_valid(self):
|
||||
"""Verifies valid JWT token and returns user email."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# First generate a valid token
|
||||
token = await service.generate_jwt_token('verify@example.com')
|
||||
|
||||
# Execute
|
||||
user_email = await service.verify_jwt_token(token)
|
||||
|
||||
# Verify
|
||||
assert user_email == 'verify@example.com'
|
||||
|
||||
async def test_verify_jwt_token_invalid_raises_error(self):
|
||||
"""Raises error for invalid JWT token."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute & Verify - invalid token should raise JWT error
|
||||
with pytest.raises(Exception): # jwt.DecodeError or similar
|
||||
await service.verify_jwt_token('invalid.token.here')
|
||||
|
||||
|
||||
class TestUserServiceResetPassword:
|
||||
"""Tests for reset_password method."""
|
||||
|
||||
async def test_reset_password_updates_password(self):
|
||||
"""Updates user password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
await service.reset_password('test@example.com', 'new_password')
|
||||
|
||||
# Verify - execute_async was called with update
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestUserServiceChangePassword:
|
||||
"""Tests for change_password method."""
|
||||
|
||||
async def test_change_password_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock get_user_by_email to return None
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='User not found'):
|
||||
await service.change_password('nonexistent@example.com', 'current', 'new')
|
||||
|
||||
async def test_change_password_no_local_password_raises_error(self):
|
||||
"""Raises ValueError when user has no local password set."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock user without password
|
||||
mock_user = _create_mock_user(email='nopass@example.com', password=None)
|
||||
service.get_user_by_email = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='No local password set'):
|
||||
await service.change_password('nopass@example.com', 'current', 'new')
|
||||
|
||||
|
||||
class TestUserServiceGetFirstUser:
|
||||
"""Tests for get_first_user method."""
|
||||
|
||||
async def test_get_first_user_found(self):
|
||||
"""Returns first user when exists."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_user = _create_mock_user(email='first@example.com')
|
||||
mock_result = _create_mock_result([mock_user])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_first_user()
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.user == 'first@example.com'
|
||||
|
||||
async def test_get_first_user_not_found(self):
|
||||
"""Returns None when no users exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_first_user()
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserServiceSetPassword:
|
||||
"""Tests for set_password method."""
|
||||
|
||||
async def test_set_password_user_not_found_raises_error(self):
|
||||
"""Raises ValueError when user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock get_user_by_email to return None
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='User not found'):
|
||||
await service.set_password('nonexistent@example.com', 'new_password')
|
||||
|
||||
async def test_set_password_with_existing_password_requires_current(self):
|
||||
"""Requires current password when user has existing password."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock user with existing password
|
||||
mock_user = _create_mock_user(email='haspass@example.com', password='hashed_old_password')
|
||||
service.get_user_by_email = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Execute & Verify - should raise when no current_password provided
|
||||
with pytest.raises(ValueError, match='Current password is required'):
|
||||
await service.set_password('haspass@example.com', 'new_password')
|
||||
|
||||
|
||||
class TestUserServiceCreateOrUpdateSpaceUser:
|
||||
"""Tests for create_or_update_space_user method."""
|
||||
|
||||
async def test_create_or_update_existing_space_user(self):
|
||||
"""Updates existing Space user tokens."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock existing Space user
|
||||
existing_user = _create_mock_user(
|
||||
email='space@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='existing-space-uuid',
|
||||
)
|
||||
service.get_user_by_space_account_uuid = AsyncMock(return_value=existing_user)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=True)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute
|
||||
updated_user = await service.create_or_update_space_user(
|
||||
space_account_uuid='existing-space-uuid',
|
||||
email='space@example.com',
|
||||
access_token='new_access_token',
|
||||
refresh_token='new_refresh_token',
|
||||
api_key='new_api_key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
# Verify - update was called and user returned
|
||||
ap.persistence_mgr.execute_async.assert_called()
|
||||
assert updated_user.space_account_uuid == 'existing-space-uuid'
|
||||
|
||||
async def test_create_or_update_new_space_user_first_init(self):
|
||||
"""Creates new Space user on first initialization."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock new user to be returned after creation
|
||||
new_user = _create_mock_user(
|
||||
email='newspace@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='new-space-uuid',
|
||||
)
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1: # First check for existing user
|
||||
return None
|
||||
return new_user # After insert, return the new user
|
||||
|
||||
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=False) # Not initialized
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute
|
||||
result = await service.create_or_update_space_user(
|
||||
space_account_uuid='new-space-uuid',
|
||||
email='newspace@example.com',
|
||||
access_token='access_token',
|
||||
refresh_token='refresh_token',
|
||||
api_key='api_key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result.space_account_uuid == 'new-space-uuid'
|
||||
|
||||
async def test_create_or_update_space_user_already_initialized_raises_error(self):
|
||||
"""Raises AccountEmailMismatchError when system already initialized and user not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Mock system already initialized, no matching users
|
||||
service.get_user_by_space_account_uuid = AsyncMock(return_value=None)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=True) # Already initialized
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(AccountEmailMismatchError):
|
||||
await service.create_or_update_space_user(
|
||||
space_account_uuid='unknown-space-uuid',
|
||||
email='unknown@example.com',
|
||||
access_token='token',
|
||||
refresh_token='refresh',
|
||||
api_key='key',
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
async def test_create_or_update_space_user_no_expiry(self):
|
||||
"""Creates Space user without token expiry."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.provider_service = SimpleNamespace()
|
||||
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
new_user = _create_mock_user(
|
||||
email='noexpiry@example.com',
|
||||
account_type='space',
|
||||
space_account_uuid='noexpiry-uuid',
|
||||
)
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1: # First check for existing user
|
||||
return None
|
||||
return new_user # After insert, return the new user
|
||||
|
||||
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
|
||||
service.get_user_by_email = AsyncMock(return_value=None)
|
||||
service.is_initialized = AsyncMock(return_value=False)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
# Execute with expires_in=0 (no expiry)
|
||||
result = await service.create_or_update_space_user(
|
||||
space_account_uuid='noexpiry-uuid',
|
||||
email='noexpiry@example.com',
|
||||
access_token='token',
|
||||
refresh_token='refresh',
|
||||
api_key='key',
|
||||
expires_in=0, # No expiry
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.space_account_uuid == 'noexpiry-uuid'
|
||||
|
||||
|
||||
class TestUserServiceCreateUserLock:
|
||||
"""Tests for create_user_lock attribute."""
|
||||
|
||||
def test_create_user_lock_initialized(self):
|
||||
"""Verify create_user_lock is initialized as asyncio.Lock."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
|
||||
service = UserService(ap)
|
||||
|
||||
# Verify lock exists
|
||||
assert hasattr(service, '_create_user_lock')
|
||||
assert service._create_user_lock is not None
|
||||
506
tests/unit_tests/api/service/test_webhook_service.py
Normal file
506
tests/unit_tests/api/service/test_webhook_service.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Unit tests for WebhookService.
|
||||
|
||||
Tests webhook CRUD operations including:
|
||||
- Webhook listing
|
||||
- Webhook creation
|
||||
- Webhook retrieval by ID
|
||||
- Webhook updates
|
||||
- Webhook deletion
|
||||
- Enabled webhooks filtering
|
||||
|
||||
Source: src/langbot/pkg/api/http/service/webhook.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langbot.pkg.api.http.service.webhook import WebhookService
|
||||
from langbot.pkg.entity.persistence.webhook import Webhook
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _create_mock_webhook(
|
||||
webhook_id: int = 1,
|
||||
name: str = 'Test Webhook',
|
||||
url: str = 'http://example.com/webhook',
|
||||
description: str = 'Test Description',
|
||||
enabled: bool = True,
|
||||
) -> Mock:
|
||||
"""Helper to create mock Webhook entity."""
|
||||
webhook = Mock(spec=Webhook)
|
||||
webhook.id = webhook_id
|
||||
webhook.name = name
|
||||
webhook.url = url
|
||||
webhook.description = description
|
||||
webhook.enabled = enabled
|
||||
return webhook
|
||||
|
||||
|
||||
def _create_mock_result(items: list = None, first_item=None):
|
||||
"""Create mock result object for persistence queries."""
|
||||
result = Mock()
|
||||
result.all = Mock(return_value=items or [])
|
||||
result.first = Mock(return_value=first_item)
|
||||
return result
|
||||
|
||||
|
||||
class TestWebhookServiceGetWebhooks:
|
||||
"""Tests for get_webhooks method."""
|
||||
|
||||
async def test_get_webhooks_empty_list(self):
|
||||
"""Returns empty list when no webhooks exist."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'url': entity.url,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhooks()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_webhooks_returns_serialized_list(self):
|
||||
"""Returns serialized list of webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
webhook1 = _create_mock_webhook(webhook_id=1, name='Webhook 1')
|
||||
webhook2 = _create_mock_webhook(webhook_id=2, name='Webhook 2')
|
||||
|
||||
mock_result = _create_mock_result([webhook1, webhook2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'url': entity.url,
|
||||
'description': entity.description,
|
||||
'enabled': entity.enabled,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhooks()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0]['name'] == 'Webhook 1'
|
||||
assert result[1]['name'] == 'Webhook 2'
|
||||
|
||||
|
||||
class TestWebhookServiceCreateWebhook:
|
||||
"""Tests for create_webhook method."""
|
||||
|
||||
async def test_create_webhook_full_params(self):
|
||||
"""Creates webhook with all parameters."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Mock insert result
|
||||
insert_result = Mock()
|
||||
|
||||
# Mock select result for retrieving created webhook
|
||||
created_webhook = _create_mock_webhook(
|
||||
webhook_id=1,
|
||||
name='New Webhook',
|
||||
url='http://new.example.com/webhook',
|
||||
description='New Description',
|
||||
enabled=True,
|
||||
)
|
||||
select_result = _create_mock_result(first_item=created_webhook)
|
||||
|
||||
# execute_async returns different results
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return insert_result # Insert
|
||||
return select_result # Select
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'New Webhook',
|
||||
'url': 'http://new.example.com/webhook',
|
||||
'description': 'New Description',
|
||||
'enabled': True,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_webhook(
|
||||
name='New Webhook',
|
||||
url='http://new.example.com/webhook',
|
||||
description='New Description',
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result['name'] == 'New Webhook'
|
||||
assert result['url'] == 'http://new.example.com/webhook'
|
||||
assert result['description'] == 'New Description'
|
||||
assert result['enabled'] is True
|
||||
|
||||
async def test_create_webhook_defaults(self):
|
||||
"""Creates webhook with default description and enabled."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_webhook = _create_mock_webhook(
|
||||
webhook_id=1,
|
||||
name='Minimal Webhook',
|
||||
url='http://minimal.example.com',
|
||||
description='', # Default
|
||||
enabled=True, # Default
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return Mock() # Insert
|
||||
return _create_mock_result(first_item=created_webhook)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Minimal Webhook',
|
||||
'url': 'http://minimal.example.com',
|
||||
'description': '',
|
||||
'enabled': True,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - only name and url required
|
||||
result = await service.create_webhook(name='Minimal Webhook', url='http://minimal.example.com')
|
||||
|
||||
# Verify defaults
|
||||
assert result['description'] == ''
|
||||
assert result['enabled'] is True
|
||||
|
||||
async def test_create_webhook_disabled(self):
|
||||
"""Creates webhook with enabled=False."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
|
||||
|
||||
call_count = 0
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return Mock()
|
||||
return _create_mock_result(first_item=created_webhook)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'id': 1, 'enabled': False}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.create_webhook(name='Disabled', url='http://disabled.com', enabled=False)
|
||||
|
||||
# Verify
|
||||
assert result['enabled'] is False
|
||||
|
||||
|
||||
class TestWebhookServiceGetWebhook:
|
||||
"""Tests for get_webhook method."""
|
||||
|
||||
async def test_get_webhook_by_id_found(self):
|
||||
"""Returns webhook when found by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
webhook = _create_mock_webhook(webhook_id=1, name='Found Webhook')
|
||||
mock_result = _create_mock_result(first_item=webhook)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'id': 1,
|
||||
'name': 'Found Webhook',
|
||||
'url': 'http://example.com/webhook',
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(1)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['id'] == 1
|
||||
assert result['name'] == 'Found Webhook'
|
||||
|
||||
async def test_get_webhook_by_id_not_found(self):
|
||||
"""Returns None when webhook not found."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(999)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
async def test_get_webhook_by_id_zero(self):
|
||||
"""Handles ID=0 (edge case) correctly."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
mock_result = _create_mock_result(first_item=None)
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_webhook(0)
|
||||
|
||||
# Verify - should return None (no webhook with ID 0)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestWebhookServiceUpdateWebhook:
|
||||
"""Tests for update_webhook method."""
|
||||
|
||||
async def test_update_webhook_name_only(self):
|
||||
"""Updates only the name field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, name='Updated Name')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_url_only(self):
|
||||
"""Updates only the url field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, url='http://updated.example.com')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_description_only(self):
|
||||
"""Updates only the description field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, description='Updated description')
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_enabled_only(self):
|
||||
"""Updates only the enabled field."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(1, enabled=False)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_all_fields(self):
|
||||
"""Updates all fields at once."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_webhook(
|
||||
1,
|
||||
name='All Updated',
|
||||
url='http://all.updated.com',
|
||||
description='All updated description',
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_update_webhook_no_fields(self):
|
||||
"""Does nothing when no fields provided."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - no update parameters
|
||||
await service.update_webhook(1)
|
||||
|
||||
# Verify - no execute call since no update_data
|
||||
ap.persistence_mgr.execute_async.assert_not_called()
|
||||
|
||||
|
||||
class TestWebhookServiceDeleteWebhook:
|
||||
"""Tests for delete_webhook method."""
|
||||
|
||||
async def test_delete_webhook_by_id(self):
|
||||
"""Deletes webhook by ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
await service.delete_webhook(1)
|
||||
|
||||
# Verify
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
async def test_delete_webhook_nonexistent_id(self):
|
||||
"""Delete operation completes even for nonexistent ID."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute - should not raise
|
||||
await service.delete_webhook(999)
|
||||
|
||||
# Verify - still called
|
||||
ap.persistence_mgr.execute_async.assert_called_once()
|
||||
|
||||
|
||||
class TestWebhookServiceGetEnabledWebhooks:
|
||||
"""Tests for get_enabled_webhooks method."""
|
||||
|
||||
async def test_get_enabled_webhooks_empty(self):
|
||||
"""Returns empty list when no enabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify
|
||||
assert result == []
|
||||
|
||||
async def test_get_enabled_webhooks_filters_enabled(self):
|
||||
"""Returns only enabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# All returned webhooks should be enabled (SQL filter)
|
||||
webhook1 = _create_mock_webhook(webhook_id=1, name='Enabled 1', enabled=True)
|
||||
webhook2 = _create_mock_webhook(webhook_id=2, name='Enabled 2', enabled=True)
|
||||
|
||||
mock_result = _create_mock_result([webhook1, webhook2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'id': entity.id,
|
||||
'name': entity.name,
|
||||
'enabled': entity.enabled,
|
||||
}
|
||||
)
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert all(w['enabled'] for w in result)
|
||||
|
||||
async def test_get_enabled_webhooks_filters_disabled(self):
|
||||
"""Does not return disabled webhooks."""
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
# Empty result because query filters on enabled=True
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={})
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
# Execute
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify - should be empty (SQL would filter disabled)
|
||||
assert result == []
|
||||
Reference in New Issue
Block a user