Feat/test build (#2174)

* fix(ci): update unit-test workflow paths to match current source layout

Replace stale pkg/** filter with src/langbot/** and add uv.lock.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* docs(tests): update README to reflect current test layout

- Fix stale paths: tests/pipeline → tests/unit_tests/pipeline
- Update CI Python versions: 3.11, 3.12, 3.13
- Add test directory structure for box, config, platform, plugin, provider, storage
- Document pytest markers and uv commands
- Mention planned E2E tests

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add shared test factories package

Create tests/factories/ with reusable test factories:
- FakeApp: mock application with all dependencies
- Message chains: text_chain, mention_chain, image_chain
- Query factories: text_query, group_text_query, command_query, etc.

No test changes - maintains backward compatibility.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add fake provider factory

Add tests/factories/provider.py with:
- FakeProvider: deterministic fake LLM provider
- Error simulation: timeout, auth, rate-limit, malformed
- Request capture for assertions
- fake_model: mock model with attached provider

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add fake platform factory

Add tests/factories/platform.py with:
- FakePlatform: simulated platform adapter
- Inbound message construction: friend/group/image
- Mention-bot flag simulation
- Outbound message capture for assertions
- Streaming output support simulation
- Send failure simulation

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add comprehensive message/query factories

Extend tests/factories/message.py with:
- file_query: file attachment query
- unsupported_query: unknown message segment
- voice_query: audio/voice query
- at_all_query: group @All mention
- query_with_session: query with session object
- query_with_config: query with custom pipeline config

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add fake message flow smoke test

Create tests/smoke/test_fake_message_flow.py:
- TestFakeMessageFlow: factory verification tests
- TestMessageFlowIntegration: minimal flow smoke test
- Tests FakeApp, FakeProvider, FakePlatform, query factories
- Verifies LANGBOT_FAKE_PONG marker response
- Captures outbound messages for assertions

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add developer test-quick command

Add scripts/test-quick.sh and Makefile with:
- test-quick: runs ruff check + unit tests + smoke tests
- No real provider keys or platform accounts required
- Suitable for local branch self-test

Update tests/README.md:
- Document test-quick command
- Document test factories package
- Add smoke tests and factories directory structure

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(test): make test-quick reliable as developer gate

Fixes for D-001验收问题:
1. test-quick.sh: use set -euo pipefail, uv run ruff, no tail pipe
2. Remove unused imports in factories (app.py, platform.py, provider.py)
3. Fix unused variable in smoke test
4. Add noqa: E402 to test_n8nsvapi.py lazy imports
5. Update smoke test docs: "minimal fake flow" not full pipeline

Now test-quick is a reliable gate: lint failures exit 1, test failures propagate.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(unit): add preproc and taskmgr unit tests

U-001: Pipeline Preprocessor tests
- Normal text message processing
- Empty message handling
- Image segment with/without vision model
- Model selection and fallback
- Variable extraction

U-004: Core Task Manager tests (pattern-based)
- Task creation and tracking patterns
- Task cancellation patterns
- Scope-based cancellation
- Task type filtering
- Pruning completed tasks
- Wait all tasks

Taskmgr tests use pattern-based approach to avoid circular import
in source code (taskmgr → app → http_controller → migration → taskmgr).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(unit): add config loader unit tests

U-005: Config Loader tests
- Valid YAML config loading
- Valid JSON config loading
- Invalid YAML/JSON error behavior
- Missing config file creation from template
- Template completion for missing keys
- ConfigManager load/dump operations
- Exists check for both YAML and JSON

All tests use tmp_path fixture, no real project config.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(unit): add chat and command handler pattern tests

U-002: Chat Handler tests (pattern-based)
- Normal message event emission pattern
- prevent_default handling
- User message alteration pattern
- Runner selection pattern
- Streaming/non-streaming response patterns
- Exception handling modes (show-error, show-hint, hide)
- Message history update pattern
- Telemetry payload pattern

U-003: Command Handler tests (pattern-based)
- Command parsing and text extraction
- Event creation pattern
- Privilege/admin check pattern
- Command result handling (text, error, image)
- prevent_default handling
- String truncation helper

Uses pattern-based testing to avoid circular import issues in source code.
Direct imports of handler modules trigger circular import chain.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* style: fix unused imports after ruff auto-fix

Remove unused imports in test files:
- test_config_loader.py: remove unused os
- test_taskmgr.py: remove unused Mock
- test_preproc.py: remove unused unsupported_query, image_chain

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(unit): improve taskmgr tests to test real classes

U-004 improved: Tests now import and test actual classes:
- TaskContext: new(), trace(), to_dict(), placeholder()
- TaskWrapper: task creation, context, exception/result capture, cancel, to_dict
- AsyncTaskManager: create_task, create_user_task, cancel_task, cancel_by_scope
- Task pruning behavior

Uses pre-mocking technique:
- Mock langbot.pkg.core.app before import (breaks circular chain)
- Mock langbot.pkg.core.entities with proper Enum

All 24 tests now test real class behavior, not patterns.
taskmgr.py coverage should improve significantly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* refactor(test): consolidate FakeApp and add sys.modules isolation utility

- Extract tests/utils/import_isolation.py with isolated_sys_modules context manager
- Extend tests/factories/app.py FakeApp with handler-specific attributes
- Refactor test_chat_handler.py to use centralized FakeApp and cached imports
- Refactor test_command_handler.py with mock_execute_factory fixture
- Refactor test_smoke.py to move import-time sys.modules manipulation into fixture
- Add SQLite migration integration tests (G-002)
- Add HTTP API smoke integration tests (G-005)
- Update CI workflow to call pytest for SQLite migrations (G-004)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add developer quality gate consolidation (G-007)

- Add scripts/test-integration-fast.sh for fast integration tests
- Add scripts/test-coverage.sh with 12% baseline threshold
- Update Makefile with test-integration-fast, test-coverage, test-all-local
- Update CI workflow with integration and coverage jobs
- Add smoke marker to pytest.ini
- Update tests/README.md with quality gate layers documentation
- Add tests/integration/pipeline/ for pipeline stage-chain tests

Quality gate layers:
- Quick: ruff + unit + smoke (~2 min)
- Fast Integration: SQLite/API/Pipeline (~3 min)
- Coverage: 12% threshold gate (~8 min)
- Full Local: all three combined

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): add PostgreSQL migration slow integration tests (G-003)

- Add tests/integration/persistence/test_migrations_postgres.py
- All tests marked with @pytest.mark.slow
- Tests skip when TEST_POSTGRES_URL is not set (no local PostgreSQL)
- Database isolation via clean_tables and clean_alembic_version fixtures
- Update CI workflow to use pytest instead of inline Python script
- Remove TODO(G-003) comment
- Update tests/README.md with PostgreSQL test documentation

Covered scenarios:
- Baseline stamp sets revision
- Upgrade from baseline to head
- Upgrade idempotent
- Get current on unstamped DB returns None

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat(test): Phase 1.5 coverage expansion - COV-001 to COV-013

Coverage baseline raised from 13.65% to 26% (+12.35%)
Gate raised from 12% to 18%

Tasks completed:
- COV-001: Command system unit tests (100% coverage)
- COV-002: API service unit tests batch 1 (user/apikey/model/provider)
- COV-003: Provider model manager unit tests
- COV-004: Pipeline remaining stage tests (aggregator/cntfilter/longtext/msgtrun)
- COV-005: Storage and utils coverage pass
- COV-006: Gate ratchet 12%→15%
- COV-007: Gate ratchet 15%→18%
- COV-008: API service batch 2 (bot/pipeline/webhook/space/maintenance/mcp)
- COV-009: Blocked - API controller circular import issue documented
- COV-010: Plugin runtime unit tests (+0.08%)
- COV-011: RAG and vector unit tests (+0.68%)
- COV-012: Core boot and migration unit tests
- COV-013: Provider requester logic unit tests (+0.62%)

Key additions:
- tests/utils/import_isolation.py: sys.modules isolation for circular imports
- Provider requester mock tests: proved HTTP-dependent code can be tested locally
- Vector filter utilities: 100% coverage on pure functions
- API services: fake persistence pattern for unit testing

Blocked issue COV-009 documented in langbot-test-plan/1.5/issues/

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(phase1): add unit tests for telemetry, plugin, rag, persistence

Add initial unit tests for Phase 1 of test coverage improvement:
- telemetry: test initialization, payload sanitization, early returns (14.3% → 62.9%)
- plugin: test _parse_plugin_id static method
- rag: test _to_i18n_name static method
- persistence: test serialize_model with datetime handling

Overall core coverage: 41.9% → 42.2%

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(phase2): add unit tests for core, persistence, plugin, utils

- Add test_handler_helpers.py for plugin handler helpers (7 tests)
- Add test_mgr_methods.py for persistence manager (5 tests)
- Add test_app_config_validation.py for core app config (12 tests)
- Add test_knowledge_service.py for API knowledge service (22 tests)
- Add test_kbmgr.py for RAG knowledge base manager (39 tests)
- Add test_survey_manager.py for survey manager (22 tests)
- Add test_connector_methods.py for plugin connector (24 tests)
- Add test_funcschema.py for utils function schema (9 tests)
- Add test_platform.py for utils platform detection (7 tests)
- Add test_extract_deps.py for plugin deps extraction (7 tests)
- Add test_database_decorator.py for persistence decorator (7 tests)
- Add test_load_config.py for core config loading (19 tests)
- Add COVERAGE_EXCLUSIONS.md documenting external adapter exclusions
- Fix test_chat_session_limit.py path for portability

Coverage: core 28% → 30%, persistence 24% → 24.4%, plugin 27% → 28%
Total: 1082 tests passed, core module coverage 45.5%

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(integration): add API controller integration tests

- Add test_pipelines.py (10 tests) covering pipelines CRUD operations
  - GET/POST/PUT/DELETE on /api/v1/pipelines
  - Extensions endpoint
  - Metadata endpoint
  - Coverage: pipelines controller 27% → 80%

- Add test_providers.py (10 tests) covering provider/model management
  - Provider CRUD with model counts
  - LLM model CRUD
  - Coverage: providers controller 23% → 81%, models 29% → 45%

Tests use Quart TestClient with mocked services for real HTTP behavior
without external dependencies.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(integration): add knowledge, bots, and model endpoints tests

- Add test_knowledge.py (10 tests) covering knowledge base management
  - CRUD operations on /api/v1/knowledge/bases
  - Files management endpoints
  - Retrieve endpoint with validation
  - Coverage: knowledge/base.py 26% → 91%

- Add test_bots.py (9 tests) covering bot management
  - CRUD operations on /api/v1/platform/bots
  - Logs endpoint
  - Send message endpoint with validation
  - Coverage: platform/bots.py 24% → 87%

- Extend test_providers.py (+4 tests) for embedding/rerank models
  - Embedding models CRUD
  - Rerank models CRUD
  - Coverage: provider/models.py 29% → 60%

Total integration tests: 53 (smoke 12 + pipelines 10 + providers 14 + knowledge 10 + bots 9)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(integration): add embed and monitoring endpoint tests

Add integration tests for embed widget and monitoring API endpoints:
- test_embed.py: 15 tests for widget.js, logo, turnstile, messages, reset, feedback
- test_monitoring.py: 15 tests for overview, messages, llm-calls, sessions, errors, export

Coverage improvements:
- embed.py: 17% → 56%
- monitoring.py: 17% → 93%

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(e2e): add minimal startup E2E tests

Add E2E tests for LangBot startup flow:
- tests/e2e/utils/config_factory.py: minimal config generation
- tests/e2e/utils/process_manager.py: LangBot subprocess management
- tests/e2e/conftest.py: E2E fixtures (session-scoped process)
- tests/e2e/test_startup.py: 12 tests for startup verification

Tests verify:
- boot.py + stages execution
- database initialization (SQLite)
- API availability
- migrations applied

Uses embedded databases (SQLite, Chroma) - no external dependencies.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test(quality): fix fake tests and add missing coverage

P0 fixes:
- telemetry: rewrite fake tests with real behavior verification (25 tests)
- config: delete copied-source tests, use proper imports (2 deleted)
- persistence: fix try-except pass to verify specific errors

P1 fixes:
- pipeline: add real FixedWindowAlgo tests instead of mocks (12 tests)
- provider: add SessionManager and ToolManager tests (25 tests)
- storage: add S3StorageProvider tests with moto mock (16 tests)
- plugin: add handler action tests for setting inheritance (15 tests)
- rag: add file storage and ZIP processing tests (21 tests)
- vector: add VDB filter conversion tests (30 tests)

P2 fixes:
- pipeline/msgtrun: strengthen assertions for exact message count
- api: add response structure validation in integration tests

New test files:
- provider/test_session_manager.py
- provider/test_tool_manager.py
- storage/test_s3storage.py
- plugin/test_handler_actions.py
- rag/test_file_storage.py
- vector/test_vdb_filter_conversion.py

Source code bugs documented:
- provider: TokenManager.next_token() ZeroDivisionError
- telemetry: send_tasks class variable shared state
- command: empty command IndexError, unused parameters
- utils: funcschema KeyError
- entity: vector.py independent declarative_base

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* docs(test): update coverage stats and test structure

- Update coverage from 22% to 30%
- Add new test files to structure:
  - provider: session_manager, tool_manager
  - storage: s3storage
  - plugin: handler_actions
  - rag: file_storage
  - vector: vdb_filter_conversion
  - telemetry: rewritten tests
- Update module coverage percentages

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* test: add 105 new unit tests for untested core functionality

Add comprehensive tests for B-class issues (core functionality untested):

Pipeline:
- test_pool.py: QueryPool ID generation, caching, async context (12 tests)
- test_ratelimit.py: Fixed timing-sensitive test tolerance
- test_pipelinemgr.py: Use real Pydantic StageProcessResult instead of Mock

Utils:
- test_version.py: Version comparison functions (20 tests)
- test_logcache.py: Log page management and retrieval (18 tests)
- test_httpclient.py: HTTP session pool management (10 tests)
- test_proxy.py: Proxy configuration from env and config (10 tests)
- test_image.py: URL parsing and base64 extraction (12 tests)
- test_pkgmgr.py: Pip command generation (8 tests)

Discover:
- test_engine.py: I18nString, Metadata, Component manifest (15 tests)

Test count: 1193 → 1298 (+105 tests)

Note: Some B-class issues cannot be tested due to circular import bugs
filed as GitHub issues #2175 (pipeline) and #2176 (persistence).

* test: tighten phase 1 coverage contracts

* test: align ci integration isolation

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
huanghuoguoguo
2026-05-16 12:05:54 +08:00
committed by GitHub
parent 4a4c0921a4
commit 17bbc8bf10
130 changed files with 32711 additions and 889 deletions
+295
View File
@@ -0,0 +1,295 @@
"""
Test fixtures for provider/modelmgr tests.
Provides fake persistence, mock requester registry, and test utilities
without calling real LLM APIs or network requests.
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.provider.modelmgr import requester
from langbot.pkg.provider.modelmgr import token
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
from langbot.pkg.entity.persistence import model as persistence_model
from langbot.pkg.discover import engine as discover_engine
class FakeProviderAPIRequester(requester.ProviderAPIRequester):
"""Fake requester for testing that does not make real API calls."""
name = 'fake-requester'
default_config = {'base_url': 'https://fake-api.example.com', 'timeout': 30}
def __init__(self, ap, config: dict):
super().__init__(ap, config)
self._invoke_count = 0
self._last_messages = None
self._last_model = None
async def invoke_llm(
self,
query,
model: requester.RuntimeLLMModel,
messages: list,
funcs=None,
extra_args={},
remove_think=False,
):
"""Return a fake message response."""
self._invoke_count += 1
self._last_messages = messages
self._last_model = model
# Import the message entity for response
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Fake LLM response')],
)
async def invoke_llm_stream(
self,
query,
model: requester.RuntimeLLMModel,
messages: list,
funcs=None,
extra_args={},
remove_think=False,
):
"""Yield fake message chunks."""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
yield provider_message.MessageChunk(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Fake stream chunk')],
)
async def invoke_embedding(self, model, input_text: list, extra_args={}):
"""Return fake embedding vectors."""
return [[0.1, 0.2, 0.3] for _ in input_text]
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results."""
return [{'index': i, 'relevance_score': 0.9 - i * 0.1} for i in range(len(documents))]
class AnotherFakeRequester(requester.ProviderAPIRequester):
"""Another fake requester for multi-requester tests."""
name = 'another-fake-requester'
default_config = {'base_url': 'https://another-fake.example.com'}
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results."""
return [{'index': i, 'relevance_score': 0.9 - i * 0.1} for i in range(len(documents))]
def _create_fake_component(name: str, requester_class: type) -> Mock:
"""Create a fake Component mock for a requester."""
# Use Mock to allow overriding get_python_component_class
component = Mock(spec=discover_engine.Component)
component.metadata = Mock()
component.metadata.name = name
component.get_python_component_class = Mock(return_value=requester_class)
return component
def _make_mock_result(items: list = None, first_item=None):
"""Create a mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
def _make_row_mock(entity):
"""Create a mock Row-like object that can be unpacked via _mapping.
Note: This function returns the actual entity directly since Mock objects
don't pass isinstance(provider_info, sqlalchemy.Row) checks. The code
in modelmgr.load_provider handles this via the else branch.
"""
return entity
@pytest.fixture
def mock_app_for_modelmgr():
"""Provides a mock Application for ModelManager tests."""
app = SimpleNamespace()
app.logger = Mock()
app.logger.debug = Mock()
app.logger.info = Mock()
app.logger.warning = Mock()
app.logger.error = Mock()
# Fake persistence manager - returns empty results by default
app.persistence_mgr = SimpleNamespace()
async def default_execute(query):
return _make_mock_result([])
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
# Fake discover engine
app.discover = SimpleNamespace()
app.discover.get_components_by_kind = Mock(return_value=[])
# Fake instance config
app.instance_config = SimpleNamespace()
app.instance_config.data = {'space': {'disable_models_service': True}}
# Other services (not used in basic tests)
app.space_service = AsyncMock()
app.llm_model_service = AsyncMock()
app.embedding_models_service = AsyncMock()
app.monitoring_service = AsyncMock()
return app
@pytest.fixture
def fake_requester_registry(mock_app_for_modelmgr):
"""Provides a ModelManager with fake requester registry."""
app = mock_app_for_modelmgr
# Create fake components
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
app.discover.get_components_by_kind = Mock(
return_value=[fake_component, another_component]
)
model_mgr = ModelManager(app)
return model_mgr
@pytest.fixture
def fake_persistence_data():
"""Provides fake persistence data for models and providers."""
provider_uuid = 'test-provider-uuid'
provider_uuid2 = 'test-provider-uuid-2'
providers = [
persistence_model.ModelProvider(
uuid=provider_uuid,
name='Test Provider',
requester='fake-requester',
base_url='https://test.example.com',
api_keys=['test-api-key-1', 'test-api-key-2'],
),
persistence_model.ModelProvider(
uuid=provider_uuid2,
name='Test Provider 2',
requester='another-fake-requester',
base_url='https://test2.example.com',
api_keys=['key-3'],
),
]
llm_models = [
persistence_model.LLMModel(
uuid='test-llm-uuid-1',
name='TestLLM-1',
provider_uuid=provider_uuid,
abilities=['func_call'],
extra_args={'temperature': 0.7},
),
persistence_model.LLMModel(
uuid='test-llm-uuid-2',
name='TestLLM-2',
provider_uuid=provider_uuid,
abilities=['vision'],
extra_args={},
),
]
embedding_models = [
persistence_model.EmbeddingModel(
uuid='test-embedding-uuid-1',
name='TestEmbedding-1',
provider_uuid=provider_uuid,
extra_args={'dimensions': 768},
),
]
rerank_models = [
persistence_model.RerankModel(
uuid='test-rerank-uuid-1',
name='TestRerank-1',
provider_uuid=provider_uuid2,
extra_args={},
),
]
return {
'providers': providers,
'llm_models': llm_models,
'embedding_models': embedding_models,
'rerank_models': rerank_models,
'provider_uuid': provider_uuid,
'provider_uuid2': provider_uuid2,
}
@pytest.fixture
def runtime_provider(fake_persistence_data, mock_app_for_modelmgr):
"""Provides a RuntimeProvider instance for testing."""
provider_entity = fake_persistence_data['providers'][0]
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
requester_inst = FakeProviderAPIRequester(mock_app_for_modelmgr, {'base_url': provider_entity.base_url})
return requester.RuntimeProvider(
provider_entity=provider_entity,
token_mgr=token_mgr,
requester=requester_inst,
)
@pytest.fixture
def runtime_llm_model(fake_persistence_data, runtime_provider):
"""Provides a RuntimeLLMModel instance for testing."""
model_entity = fake_persistence_data['llm_models'][0]
return requester.RuntimeLLMModel(
model_entity=model_entity,
provider=runtime_provider,
)
@pytest.fixture
def runtime_embedding_model(fake_persistence_data, runtime_provider):
"""Provides a RuntimeEmbeddingModel instance for testing."""
model_entity = fake_persistence_data['embedding_models'][0]
return requester.RuntimeEmbeddingModel(
model_entity=model_entity,
provider=runtime_provider,
)
@pytest.fixture
def runtime_rerank_model(fake_persistence_data, mock_app_for_modelmgr):
"""Provides a RuntimeRerankModel instance for testing."""
provider_entity = fake_persistence_data['providers'][1]
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
requester_inst = AnotherFakeRequester(mock_app_for_modelmgr, {'base_url': provider_entity.base_url})
provider = requester.RuntimeProvider(
provider_entity=provider_entity,
token_mgr=token_mgr,
requester=requester_inst,
)
model_entity = fake_persistence_data['rerank_models'][0]
return requester.RuntimeRerankModel(
model_entity=model_entity,
provider=provider,
)
@@ -0,0 +1,32 @@
"""Tests for AnthropicMessages requester.
Tests config and pure utility methods.
"""
from __future__ import annotations
from unittest.mock import MagicMock
class TestAnthropicMessagesConfig:
"""Tests for default config."""
def test_default_config_values(self):
"""Check default_config."""
from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages
assert AnthropicMessages.default_config['base_url'] == 'https://api.anthropic.com'
assert AnthropicMessages.default_config['timeout'] == 120
def test_config_override(self):
"""Config can override defaults."""
from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages
mock_app = MagicMock()
req = AnthropicMessages(mock_app, {
'base_url': 'https://custom.anthropic.com',
'timeout': 60,
})
assert req.requester_cfg['base_url'] == 'https://custom.anthropic.com'
assert req.requester_cfg['timeout'] == 60
@@ -0,0 +1,247 @@
"""Tests for requester error handling - direct import version.
Tests error handling branches by importing real packages and mocking
only the necessary dependencies.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
import openai # Import real openai package
from langbot.pkg.provider.modelmgr.errors import RequesterError
class TestInvokeLLMErrorHandling:
"""Tests for invoke_llm error handling branches."""
@pytest.fixture
def mock_app(self):
"""Create mock Application."""
app = MagicMock()
app.tool_mgr = MagicMock()
app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[])
return app
@pytest.fixture
def mock_model(self):
"""Create mock RuntimeLLMModel."""
model = MagicMock()
model.model_entity = MagicMock()
model.model_entity.name = 'gpt-4'
model.provider = MagicMock()
model.provider.token_mgr = MagicMock()
model.provider.token_mgr.get_token = MagicMock(return_value='test-key')
return model
@pytest.fixture
def mock_message(self):
"""Create mock provider message."""
msg = MagicMock()
msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'})
return msg
@pytest.fixture
def requester_with_mocked_client(self, mock_app):
"""Create requester with mocked OpenAI client."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
req = OpenAIChatCompletions(mock_app, {
'base_url': 'https://api.openai.com/v1',
'timeout': 120,
})
# Replace client with mock
req.client = MagicMock()
req.client.chat = MagicMock()
req.client.chat.completions = MagicMock()
req.client.chat.completions.create = AsyncMock()
return req
@pytest.mark.asyncio
async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message):
"""TimeoutError is wrapped as RequesterError."""
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
side_effect=asyncio.TimeoutError()
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_llm(
query=None,
model=mock_model,
messages=[mock_message],
)
assert '超时' in str(exc.value)
@pytest.mark.asyncio
async def test_bad_request_context_length(self, requester_with_mocked_client, mock_model, mock_message):
"""BadRequestError with context_length_exceeded has special message."""
error = openai.BadRequestError(
message='context_length_exceeded: max 4096',
response=MagicMock(status_code=400),
body={}
)
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
side_effect=error
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_llm(
query=None,
model=mock_model,
messages=[mock_message],
)
assert '上文过长' in str(exc.value)
@pytest.mark.asyncio
async def test_authentication_error(self, requester_with_mocked_client, mock_model, mock_message):
"""AuthenticationError shows invalid api-key message."""
error = openai.AuthenticationError(
message='Invalid API key',
response=MagicMock(status_code=401),
body={}
)
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
side_effect=error
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_llm(
query=None,
model=mock_model,
messages=[mock_message],
)
assert 'api-key' in str(exc.value).lower() or '无效' in str(exc.value)
@pytest.mark.asyncio
async def test_rate_limit_error(self, requester_with_mocked_client, mock_model, mock_message):
"""RateLimitError shows rate limit message."""
error = openai.RateLimitError(
message='Rate limit exceeded',
response=MagicMock(status_code=429),
body={}
)
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
side_effect=error
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_llm(
query=None,
model=mock_model,
messages=[mock_message],
)
assert '频繁' in str(exc.value) or '余额' in str(exc.value)
class TestInvokeEmbeddingErrorHandling:
"""Tests for invoke_embedding error handling."""
@pytest.fixture
def mock_app(self):
return MagicMock()
@pytest.fixture
def mock_embedding_model(self):
model = MagicMock()
model.model_entity = MagicMock()
model.model_entity.name = 'text-embedding-ada-002'
model.model_entity.extra_args = {}
model.provider = MagicMock()
model.provider.token_mgr = MagicMock()
model.provider.token_mgr.get_token = MagicMock(return_value='test-key')
return model
@pytest.fixture
def requester_with_mocked_client(self, mock_app):
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
req = OpenAIChatCompletions(mock_app, {})
req.client = MagicMock()
req.client.embeddings = MagicMock()
req.client.embeddings.create = AsyncMock()
return req
@pytest.mark.asyncio
async def test_embedding_timeout_error(self, requester_with_mocked_client, mock_embedding_model):
"""TimeoutError in embedding request."""
requester_with_mocked_client.client.embeddings.create = AsyncMock(
side_effect=asyncio.TimeoutError()
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_embedding(
model=mock_embedding_model,
input_text=['test'],
)
assert '超时' in str(exc.value)
@pytest.mark.asyncio
async def test_embedding_bad_request_error(self, requester_with_mocked_client, mock_embedding_model):
"""BadRequestError in embedding request."""
error = openai.BadRequestError(
message='Invalid model',
response=MagicMock(status_code=400),
body={}
)
requester_with_mocked_client.client.embeddings.create = AsyncMock(
side_effect=error
)
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_embedding(
model=mock_embedding_model,
input_text=['test'],
)
assert '参数' in str(exc.value)
class TestRequesterErrorClass:
"""Tests for RequesterError."""
def test_error_message_prefix(self):
"""RequesterError has '模型请求失败' prefix."""
from langbot.pkg.provider.modelmgr.errors import RequesterError
error = RequesterError('test error')
assert '模型请求失败' in str(error)
def test_error_is_exception(self):
"""RequesterError inherits Exception."""
from langbot.pkg.provider.modelmgr.errors import RequesterError
error = RequesterError('test')
assert isinstance(error, Exception)
class TestDefaultConfig:
"""Tests for requester default config."""
def test_default_config(self):
"""Check default_config values."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
assert OpenAIChatCompletions.default_config['base_url'] == 'https://api.openai.com/v1'
assert OpenAIChatCompletions.default_config['timeout'] == 120
def test_config_override(self):
"""Config overrides defaults."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
req = OpenAIChatCompletions(MagicMock(), {
'base_url': 'https://custom.com/v1',
'timeout': 60,
})
assert req.requester_cfg['base_url'] == 'https://custom.com/v1'
assert req.requester_cfg['timeout'] == 60
@@ -0,0 +1,340 @@
"""Tests for requester pure utility functions.
Tests the helper methods in OpenAIChatCompletions that don't require network calls.
"""
from __future__ import annotations
from unittest.mock import MagicMock
from tests.utils.import_isolation import isolated_sys_modules
class TestMaskApiKey:
"""Tests for _mask_api_key method."""
def _create_requester_with_mocks(self):
"""Create requester instance with mocked dependencies."""
mocks = {
'langbot.pkg.core.app': MagicMock(),
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
}
with isolated_sys_modules(mocks):
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
mock_app = MagicMock()
requester = OpenAIChatCompletions(mock_app, {})
return requester
def test_mask_api_key_full(self):
"""Mask a full API key."""
requester = self._create_requester_with_mocks()
result = requester._mask_api_key('sk-1234567890abcdef')
assert result == 'sk-1...cdef'
def test_mask_api_key_short(self):
"""Mask a short API key (<=8 chars)."""
requester = self._create_requester_with_mocks()
result = requester._mask_api_key('short')
assert result == '****'
def test_mask_api_key_empty(self):
"""Empty API key returns empty string."""
requester = self._create_requester_with_mocks()
result = requester._mask_api_key('')
assert result == ''
def test_mask_api_key_none(self):
"""None API key returns empty string."""
requester = self._create_requester_with_mocks()
result = requester._mask_api_key(None)
assert result == ''
def test_mask_api_key_exact_8_chars(self):
"""API key with exactly 8 chars is masked as **** (<=8 threshold)."""
requester = self._create_requester_with_mocks()
result = requester._mask_api_key('12345678')
assert result == '****' # <= 8 chars gets masked
class TestInferModelType:
"""Tests for _infer_model_type method."""
def _create_requester_with_mocks(self):
mocks = {
'langbot.pkg.core.app': MagicMock(),
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
}
with isolated_sys_modules(mocks):
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
mock_app = MagicMock()
requester = OpenAIChatCompletions(mock_app, {})
return requester
def test_infer_embedding_from_name(self):
"""Infer embedding type from model name."""
requester = self._create_requester_with_mocks()
assert requester._infer_model_type('text-embedding-ada-002') == 'embedding'
assert requester._infer_model_type('bge-large-en') == 'embedding'
assert requester._infer_model_type('e5-base') == 'embedding'
assert requester._infer_model_type('m3e-base') == 'embedding'
def test_infer_llm_from_name(self):
"""Infer LLM type from model name."""
requester = self._create_requester_with_mocks()
assert requester._infer_model_type('gpt-4') == 'llm'
assert requester._infer_model_type('claude-3-opus') == 'llm'
assert requester._infer_model_type('llama-2-70b') == 'llm'
def test_infer_model_type_none_id(self):
"""Handle None model_id."""
requester = self._create_requester_with_mocks()
result = requester._infer_model_type(None)
assert result == 'llm' # Default
def test_infer_model_type_empty_id(self):
"""Handle empty model_id."""
requester = self._create_requester_with_mocks()
result = requester._infer_model_type('')
assert result == 'llm' # Default
class TestNormalizeModalities:
"""Tests for _normalize_modalities method."""
def _create_requester_with_mocks(self):
mocks = {
'langbot.pkg.core.app': MagicMock(),
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
}
with isolated_sys_modules(mocks):
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
mock_app = MagicMock()
requester = OpenAIChatCompletions(mock_app, {})
return requester
def test_normalize_string_modality(self):
"""Normalize single string modality."""
requester = self._create_requester_with_mocks()
result = requester._normalize_modalities('text,image')
assert result == ['text', 'image']
def test_normalize_list_modalities(self):
"""Normalize list of modalities."""
requester = self._create_requester_with_mocks()
result = requester._normalize_modalities(['text', 'image', 'audio'])
assert result == ['text', 'image', 'audio']
def test_normalize_dict_modalities(self):
"""Normalize dict with nested modalities."""
requester = self._create_requester_with_mocks()
result = requester._normalize_modalities({'input': ['text'], 'output': ['text', 'image']})
assert result == ['text', 'image']
def test_normalize_none(self):
"""Handle None input."""
requester = self._create_requester_with_mocks()
result = requester._normalize_modalities(None)
assert result == []
def test_normalize_arrow_separator(self):
"""Handle arrow separator in modality string."""
requester = self._create_requester_with_mocks()
result = requester._normalize_modalities('text->image')
assert result == ['text', 'image']
class TestParseRerankResponse:
"""Tests for _parse_rerank_response static method."""
def test_parse_cohere_jina_format(self):
"""Parse Cohere/Jina/SiliconFlow format."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
data = {
'results': [
{'index': 0, 'relevance_score': 0.95},
{'index': 1, 'relevance_score': 0.80},
]
}
result = OpenAIChatCompletions._parse_rerank_response(data)
assert result == [
{'index': 0, 'relevance_score': 0.95},
{'index': 1, 'relevance_score': 0.80},
]
def test_parse_voyage_format(self):
"""Parse Voyage AI format."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
data = {
'data': [
{'index': 0, 'relevance_score': 0.90},
{'index': 2, 'relevance_score': 0.75},
]
}
result = OpenAIChatCompletions._parse_rerank_response(data)
assert result == [
{'index': 0, 'relevance_score': 0.90},
{'index': 2, 'relevance_score': 0.75},
]
def test_parse_dashscope_format(self):
"""Parse DashScope format."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
data = {
'output': {
'results': [
{'index': 0, 'relevance_score': 0.85},
]
}
}
result = OpenAIChatCompletions._parse_rerank_response(data)
assert result == [{'index': 0, 'relevance_score': 0.85}]
def test_parse_unknown_format(self):
"""Handle unknown format returns empty list."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
data = {'unknown_key': 'value'}
result = OpenAIChatCompletions._parse_rerank_response(data)
assert result == []
def test_parse_empty_results(self):
"""Handle empty results."""
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
data = {'results': []}
result = OpenAIChatCompletions._parse_rerank_response(data)
assert result == []
class TestExtractScanMetadata:
"""Tests for _extract_scan_metadata method."""
def _create_requester_with_mocks(self):
mocks = {
'langbot.pkg.core.app': MagicMock(),
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
}
with isolated_sys_modules(mocks):
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
mock_app = MagicMock()
requester = OpenAIChatCompletions(mock_app, {})
return requester
def test_extract_basic_metadata(self):
"""Extract basic model metadata."""
requester = self._create_requester_with_mocks()
item = {
'id': 'gpt-4',
'name': 'GPT-4 Turbo',
'description': 'Most capable GPT-4 model',
'context_length': 128000,
'owned_by': 'openai',
}
result = requester._extract_scan_metadata(item, 'gpt-4')
assert result['display_name'] == 'GPT-4 Turbo'
assert result['description'] == 'Most capable GPT-4 model'
assert result['context_length'] == 128000
assert result['owned_by'] == 'openai'
def test_extract_metadata_missing_fields(self):
"""Handle missing metadata fields."""
requester = self._create_requester_with_mocks()
item = {'id': 'unknown-model'}
result = requester._extract_scan_metadata(item, 'unknown-model')
assert result['display_name'] is None
assert result['description'] is None
assert result['context_length'] is None
assert result['owned_by'] is None
def test_extract_metadata_top_provider_context(self):
"""Extract context_length from top_provider."""
requester = self._create_requester_with_mocks()
item = {
'id': 'model',
'top_provider': {
'context_length': 4096,
},
}
result = requester._extract_scan_metadata(item, 'model')
assert result['context_length'] == 4096
def test_extract_metadata_empty_strings(self):
"""Handle empty string values."""
requester = self._create_requester_with_mocks()
item = {
'id': 'model',
'name': '', # Empty name
'description': ' ', # Whitespace only
'owned_by': '',
}
result = requester._extract_scan_metadata(item, 'model')
assert result['display_name'] is None
assert result['description'] is None
assert result['owned_by'] is None
def test_extract_metadata_name_matches_id(self):
"""When name equals id, display_name is None."""
requester = self._create_requester_with_mocks()
item = {
'id': 'gpt-4',
'name': 'gpt-4', # Same as id
}
result = requester._extract_scan_metadata(item, 'gpt-4')
assert result['display_name'] is None
@@ -0,0 +1,264 @@
"""Tests for OllamaChatCompletions requester.
Tests model inference, payload construction, and error handling.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from langbot.pkg.provider.modelmgr.errors import RequesterError
class TestOllamaRequesterConfig:
"""Tests for default config."""
def test_default_config_values(self):
"""Check default_config."""
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
assert OllamaChatCompletions.default_config['base_url'] == 'http://127.0.0.1:11434'
assert OllamaChatCompletions.default_config['timeout'] == 120
def test_config_override(self):
"""Config can override defaults."""
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
mock_app = MagicMock()
req = OllamaChatCompletions(mock_app, {
'base_url': 'http://custom.ollama:11434',
'timeout': 300,
})
assert req.requester_cfg['base_url'] == 'http://custom.ollama:11434'
assert req.requester_cfg['timeout'] == 300
class TestOllamaInferModelType:
"""Tests for _infer_model_type pure function."""
@pytest.fixture
def requester(self):
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
return OllamaChatCompletions(MagicMock(), {})
def test_infer_embedding_from_name(self, requester):
"""Embedding keywords return 'embedding'."""
assert requester._infer_model_type('nomic-embed-text') == 'embedding'
assert requester._infer_model_type('bge-large') == 'embedding'
assert requester._infer_model_type('text-embedding') == 'embedding'
def test_infer_llm_from_name(self, requester):
"""Non-embedding keywords return 'llm'."""
assert requester._infer_model_type('llama2') == 'llm'
assert requester._infer_model_type('mistral') == 'llm'
assert requester._infer_model_type('codellama') == 'llm'
def test_infer_model_type_none(self, requester):
"""None model_id returns 'llm'."""
assert requester._infer_model_type(None) == 'llm'
def test_infer_model_type_empty(self, requester):
"""Empty model_id returns 'llm'."""
assert requester._infer_model_type('') == 'llm'
class TestOllamaInferModelAbilities:
"""Tests for _infer_model_abilities pure function."""
@pytest.fixture
def requester(self):
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
return OllamaChatCompletions(MagicMock(), {})
def test_infer_vision_ability(self, requester):
"""Vision keywords add 'vision' ability."""
item = {
'details': {
'family': 'llava',
}
}
abilities = requester._infer_model_abilities(item, 'llava-v1.5')
assert 'vision' in abilities
def test_infer_vision_from_model_id(self, requester):
"""Vision keywords in model_id add 'vision' ability."""
item = {}
abilities = requester._infer_model_abilities(item, 'llava-7b')
assert 'vision' in abilities
def test_infer_func_call_ability(self, requester):
"""Tool/function keywords add 'func_call' ability."""
item = {
'details': {
'families': ['tools'],
}
}
abilities = requester._infer_model_abilities(item, 'model')
assert 'func_call' in abilities
def test_infer_no_abilities(self, requester):
"""No matching keywords returns empty abilities."""
item = {
'details': {
'family': 'llama',
}
}
abilities = requester._infer_model_abilities(item, 'llama-2')
assert len(abilities) == 0
def test_infer_multiple_abilities(self, requester):
"""Multiple keywords can add multiple abilities."""
item = {
'details': {
'family': 'vision',
'families': ['tools'],
}
}
abilities = requester._infer_model_abilities(item, 'vision-tool-model')
assert 'vision' in abilities
assert 'func_call' in abilities
class TestOllamaMakeMessage:
"""Tests for _make_msg response parsing."""
@pytest.fixture
def requester(self):
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
return OllamaChatCompletions(MagicMock(), {})
def _create_ollama_response(self, content, tool_calls=None):
"""Helper to create mock ollama response."""
import ollama
mock_response = MagicMock(spec=ollama.ChatResponse)
mock_message = MagicMock(spec=ollama.Message)
mock_message.content = content
mock_message.tool_calls = tool_calls
mock_response.message = mock_message
return mock_response
@pytest.mark.asyncio
async def test_make_msg_text_content(self, requester):
"""Text content is extracted."""
mock_response = self._create_ollama_response('Hello world')
result = await requester._make_msg(mock_response)
assert result.content == 'Hello world'
assert result.role == 'assistant'
@pytest.mark.asyncio
async def test_make_msg_with_tool_calls(self, requester):
"""Tool calls are parsed."""
mock_tool_call = MagicMock()
mock_tool_call.function = MagicMock()
mock_tool_call.function.name = 'get_weather'
mock_tool_call.function.arguments = {'location': 'Beijing'}
mock_response = self._create_ollama_response('', tool_calls=[mock_tool_call])
result = await requester._make_msg(mock_response)
assert result.tool_calls is not None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == 'get_weather'
# Arguments should be JSON string
assert isinstance(result.tool_calls[0].function.arguments, str)
@pytest.mark.asyncio
async def test_make_msg_empty_message_raises(self, requester):
"""Empty message raises ValueError."""
mock_response = MagicMock()
mock_response.message = None
with pytest.raises(ValueError, match='message'):
await requester._make_msg(mock_response)
class TestOllamaErrorHandling:
"""Tests for error handling branches."""
@pytest.fixture
def mock_app(self):
app = MagicMock()
app.tool_mgr = MagicMock()
app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[])
return app
@pytest.fixture
def requester_with_mocked_client(self, mock_app):
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
req = OllamaChatCompletions(mock_app, {})
req.client = MagicMock()
req.client.chat = AsyncMock()
return req
@pytest.fixture
def mock_model(self):
model = MagicMock()
model.model_entity = MagicMock()
model.model_entity.name = 'llama2'
model.provider = MagicMock()
model.provider.token_mgr = MagicMock()
model.provider.token_mgr.get_token = MagicMock(return_value='')
return model
@pytest.fixture
def mock_message(self):
msg = MagicMock()
msg.role = 'user'
msg.content = 'test'
msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'})
return msg
@pytest.mark.asyncio
async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message):
"""TimeoutError is converted to RequesterError."""
requester_with_mocked_client.client.chat = AsyncMock(side_effect=asyncio.TimeoutError())
with pytest.raises(RequesterError) as exc:
await requester_with_mocked_client.invoke_llm(
query=None,
model=mock_model,
messages=[mock_message],
)
assert '超时' in str(exc.value)
class TestOllamaScanModels:
"""Tests for scan_models method."""
@pytest.fixture
def mock_app(self):
return MagicMock()
@pytest.fixture
def requester(self, mock_app):
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
req = OllamaChatCompletions(mock_app, {
'base_url': 'http://127.0.0.1:11434',
'timeout': 120,
})
return req
def test_requester_name_constant(self):
"""REQUESTER_NAME constant exists."""
from langbot.pkg.provider.modelmgr.requesters.ollamachat import REQUESTER_NAME
assert REQUESTER_NAME == 'ollama-chat'
@@ -0,0 +1,169 @@
"""Tests for DifyServiceAPIRunner pure utility methods.
Tests the helper methods that don't require real Dify API calls.
"""
from __future__ import annotations
import pytest
class TestDifyExtractTextOutput:
"""Tests for _extract_dify_text_output method."""
def _create_runner(self):
"""Create runner instance."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'chat',
'api-key': 'test-key',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
runner.dify_client = MagicMock()
return runner
def test_extract_none_value(self):
"""None returns empty string."""
runner = self._create_runner()
result = runner._extract_dify_text_output(None)
assert result == ''
def test_extract_string_value(self):
"""Plain string is returned."""
runner = self._create_runner()
result = runner._extract_dify_text_output('plain text')
assert result == 'plain text'
def test_extract_dict_with_content(self):
"""Dict with 'content' key extracts content."""
runner = self._create_runner()
result = runner._extract_dify_text_output({'content': 'extracted content'})
assert result == 'extracted content'
def test_extract_dict_without_content(self):
"""Dict without 'content' key is JSON dumped."""
runner = self._create_runner()
result = runner._extract_dify_text_output({'key': 'value'})
assert 'key' in result
assert 'value' in result
def test_extract_json_string_with_content(self):
"""JSON string with 'content' key extracts content."""
runner = self._create_runner()
result = runner._extract_dify_text_output('{"content": "json content"}')
assert result == 'json content'
def test_extract_json_string_without_content(self):
"""JSON string without 'content' key returns original."""
runner = self._create_runner()
result = runner._extract_dify_text_output('{"other": "value"}')
assert '{"other": "value"}' in result
def test_extract_whitespace_string(self):
"""Whitespace string returns empty."""
runner = self._create_runner()
result = runner._extract_dify_text_output(' ')
assert result == ''
class TestDifyRunnerConfigValidation:
"""Tests for runner config validation."""
def test_invalid_app_type_raises(self):
"""Invalid app-type raises DifyAPIError."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
from langbot.libs.dify_service_api.v1.errors import DifyAPIError
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'invalid-type',
'api-key': 'test',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
with pytest.raises(DifyAPIError, match='不支持'):
DifyServiceAPIRunner(mock_app, pipeline_config)
def test_valid_app_types(self):
"""Valid app-types don't raise."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
for app_type in ['chat', 'agent', 'workflow']:
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': app_type,
'api-key': 'test',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
# Should not raise
assert runner is not None
class TestDifyRunnerInit:
"""Tests for runner initialization."""
def test_runner_stores_config(self):
"""Runner stores pipeline_config."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'chat',
'api-key': 'test-key',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app
@@ -0,0 +1,788 @@
"""
Unit tests for ModelManager in provider/modelmgr.
Tests model configuration management, requester selection, provider loading,
and error handling without calling real LLM APIs.
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
from langbot.pkg.provider.modelmgr import requester
from langbot.pkg.entity.persistence import model as persistence_model
from langbot.pkg.entity.errors import provider as provider_errors
from langbot.pkg.provider.modelmgr import token
from tests.unit_tests.provider.conftest import _make_mock_result, _make_row_mock
# ============================================================================
# ModelManager Initialization Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_initialize_with_fake_requesters(fake_requester_registry):
"""Test ModelManager initializes with fake requester registry."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
assert 'fake-requester' in model_mgr.requester_dict
assert 'another-fake-requester' in model_mgr.requester_dict
assert model_mgr.requester_dict['fake-requester'] is not None
assert len(model_mgr.requester_components) == 2
@pytest.mark.asyncio
async def test_model_manager_initialize_empty_registry(mock_app_for_modelmgr):
"""Test ModelManager handles empty requester registry."""
app = mock_app_for_modelmgr
app.discover.get_components_by_kind = Mock(return_value=[])
model_mgr = ModelManager(app)
await model_mgr.initialize()
assert model_mgr.requester_dict == {}
assert len(model_mgr.requester_components) == 0
@pytest.mark.asyncio
async def test_model_manager_skips_space_sync_when_disabled(mock_app_for_modelmgr):
"""Test ModelManager skips space sync when disabled in config."""
app = mock_app_for_modelmgr
app.instance_config.data = {'space': {'disable_models_service': True}}
model_mgr = ModelManager(app)
await model_mgr.initialize()
# Should not call space_service if disabled
app.space_service.get_models.assert_not_called()
# ============================================================================
# Model Loading Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_load_models_from_db(fake_requester_registry, fake_persistence_data):
"""Test ModelManager loads models from database correctly."""
model_mgr = fake_requester_registry
# Setup fake persistence responses - return entities directly (code handles non-Row entities)
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'llm_models' in query_str:
return _make_mock_result(fake_persistence_data['llm_models'])
elif 'embedding_models' in query_str:
return _make_mock_result(fake_persistence_data['embedding_models'])
elif 'rerank_models' in query_str:
return _make_mock_result(fake_persistence_data['rerank_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
# Check providers loaded
assert len(model_mgr.provider_dict) == 2
assert fake_persistence_data['provider_uuid'] in model_mgr.provider_dict
assert fake_persistence_data['provider_uuid2'] in model_mgr.provider_dict
# Check models loaded
assert len(model_mgr.llm_models) == 2
assert len(model_mgr.embedding_models) == 1
assert len(model_mgr.rerank_models) == 1
@pytest.mark.asyncio
async def test_model_manager_load_provider_unknown_requester(mock_app_for_modelmgr):
"""Test ModelManager raises RequesterNotFoundError for unknown requester."""
app = mock_app_for_modelmgr
app.discover.get_components_by_kind = Mock(return_value=[])
model_mgr = ModelManager(app)
await model_mgr.initialize()
provider_info = {
'uuid': 'unknown-provider',
'name': 'Unknown Provider',
'requester': 'non-existent-requester',
'base_url': 'https://unknown.com',
'api_keys': [],
}
with pytest.raises(provider_errors.RequesterNotFoundError) as exc_info:
await model_mgr.load_provider(provider_info)
assert exc_info.value.requester_name == 'non-existent-requester'
@pytest.mark.asyncio
async def test_model_manager_load_provider_from_dict(fake_requester_registry):
"""Test ModelManager loads provider from dict correctly."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
provider_info = {
'uuid': 'dict-provider-uuid',
'name': 'Dict Provider',
'requester': 'fake-requester',
'base_url': 'https://dict.example.com',
'api_keys': ['dict-key'],
}
runtime_provider = await model_mgr.load_provider(provider_info)
assert runtime_provider.provider_entity.uuid == 'dict-provider-uuid'
assert runtime_provider.provider_entity.name == 'Dict Provider'
assert runtime_provider.token_mgr.name == 'dict-provider-uuid'
assert runtime_provider.token_mgr.tokens == ['dict-key']
assert isinstance(runtime_provider.requester, requester.ProviderAPIRequester)
@pytest.mark.asyncio
async def test_model_manager_load_provider_from_entity(fake_requester_registry, fake_persistence_data):
"""Test ModelManager loads provider from persistence entity."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
provider_entity = fake_persistence_data['providers'][0]
runtime_provider = await model_mgr.load_provider(provider_entity)
assert runtime_provider.provider_entity.uuid == provider_entity.uuid
assert runtime_provider.requester is not None
# ============================================================================
# Model Query Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_get_model_by_uuid(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.get_model_by_uuid returns correct model."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'llm_models' in query_str:
return _make_mock_result(fake_persistence_data['llm_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
model = await model_mgr.get_model_by_uuid('test-llm-uuid-1')
assert model.model_entity.uuid == 'test-llm-uuid-1'
assert model.model_entity.name == 'TestLLM-1'
@pytest.mark.asyncio
async def test_model_manager_get_model_by_uuid_not_found(fake_requester_registry):
"""Test ModelManager.get_model_by_uuid raises ValueError for unknown model."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
with pytest.raises(ValueError) as exc_info:
await model_mgr.get_model_by_uuid('unknown-model-uuid')
assert 'unknown-model-uuid' in str(exc_info.value)
@pytest.mark.asyncio
async def test_model_manager_get_embedding_model_by_uuid(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.get_embedding_model_by_uuid returns correct model."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'embedding_models' in query_str:
return _make_mock_result(fake_persistence_data['embedding_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
model = await model_mgr.get_embedding_model_by_uuid('test-embedding-uuid-1')
assert model.model_entity.uuid == 'test-embedding-uuid-1'
@pytest.mark.asyncio
async def test_model_manager_get_embedding_model_by_uuid_not_found(fake_requester_registry):
"""Test ModelManager.get_embedding_model_by_uuid raises ValueError."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
with pytest.raises(ValueError):
await model_mgr.get_embedding_model_by_uuid('unknown-embedding-uuid')
@pytest.mark.asyncio
async def test_model_manager_get_rerank_model_by_uuid(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.get_rerank_model_by_uuid returns correct model."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'rerank_models' in query_str:
return _make_mock_result(fake_persistence_data['rerank_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
model = await model_mgr.get_rerank_model_by_uuid('test-rerank-uuid-1')
assert model.model_entity.uuid == 'test-rerank-uuid-1'
@pytest.mark.asyncio
async def test_model_manager_get_rerank_model_by_uuid_not_found(fake_requester_registry):
"""Test ModelManager.get_rerank_model_by_uuid raises ValueError."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
with pytest.raises(ValueError):
await model_mgr.get_rerank_model_by_uuid('unknown-rerank-uuid')
# ============================================================================
# Model Removal Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_remove_llm_model(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.remove_llm_model removes model correctly."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'llm_models' in query_str:
return _make_mock_result(fake_persistence_data['llm_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
assert len(model_mgr.llm_models) == 2
await model_mgr.remove_llm_model('test-llm-uuid-1')
assert len(model_mgr.llm_models) == 1
assert model_mgr.llm_models[0].model_entity.uuid == 'test-llm-uuid-2'
@pytest.mark.asyncio
async def test_model_manager_remove_llm_model_not_found(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.remove_llm_model handles unknown model gracefully."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'llm_models' in query_str:
return _make_mock_result(fake_persistence_data['llm_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
original_count = len(model_mgr.llm_models)
# Removing unknown model should do nothing (no error)
await model_mgr.remove_llm_model('unknown-model-uuid')
assert len(model_mgr.llm_models) == original_count
@pytest.mark.asyncio
async def test_model_manager_remove_embedding_model(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.remove_embedding_model removes model correctly."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'embedding_models' in query_str:
return _make_mock_result(fake_persistence_data['embedding_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
assert len(model_mgr.embedding_models) == 1
await model_mgr.remove_embedding_model('test-embedding-uuid-1')
assert len(model_mgr.embedding_models) == 0
@pytest.mark.asyncio
async def test_model_manager_remove_rerank_model(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.remove_rerank_model removes model correctly."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'rerank_models' in query_str:
return _make_mock_result(fake_persistence_data['rerank_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
assert len(model_mgr.rerank_models) == 1
await model_mgr.remove_rerank_model('test-rerank-uuid-1')
assert len(model_mgr.rerank_models) == 0
@pytest.mark.asyncio
async def test_model_manager_remove_provider(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.remove_provider removes provider correctly."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
return _make_mock_result(fake_persistence_data['providers'])
elif 'llm_models' in query_str:
return _make_mock_result(fake_persistence_data['llm_models'])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
assert fake_persistence_data['provider_uuid'] in model_mgr.provider_dict
await model_mgr.remove_provider(fake_persistence_data['provider_uuid'])
assert fake_persistence_data['provider_uuid'] not in model_mgr.provider_dict
# ============================================================================
# Requester Info Tests
# ============================================================================
def test_model_manager_get_available_requesters_info(fake_requester_registry):
"""Test ModelManager.get_available_requesters_info returns correct info."""
model_mgr = fake_requester_registry
model_mgr.requester_components = []
info = model_mgr.get_available_requesters_info('')
assert info == []
def test_model_manager_get_available_requesters_info_with_type_filter(fake_requester_registry):
"""Test ModelManager.get_available_requesters_info filters by model type."""
model_mgr = fake_requester_registry
from langbot.pkg.discover import engine as discover_engine
manifest = {
'apiVersion': 'v1',
'kind': 'LLMAPIRequester',
'metadata': {'name': 'test-req', 'label': {'en_US': 'Test'}, 'description': {'en_US': 'Test'}},
'spec': {'support_type': ['chat', 'embedding']},
'execution': {'python': {'path': 'fake', 'attr': 'FakeClass'}},
}
component = discover_engine.Component(owner='test', manifest=manifest, rel_path='fake.yaml')
model_mgr.requester_components = [component]
# Filter by chat type
info = model_mgr.get_available_requesters_info('chat')
assert len(info) == 1
assert info[0]['name'] == 'test-req'
# Filter by unsupported type
info = model_mgr.get_available_requesters_info('rerank')
assert len(info) == 0
def test_model_manager_get_available_requester_info_by_name(fake_requester_registry):
"""Test ModelManager.get_available_requester_info_by_name returns correct info."""
model_mgr = fake_requester_registry
from langbot.pkg.discover import engine as discover_engine
manifest = {
'apiVersion': 'v1',
'kind': 'LLMAPIRequester',
'metadata': {'name': 'named-req', 'label': {'en_US': 'Named'}, 'description': {'en_US': 'Named'}},
'spec': {'support_type': ['chat']},
'execution': {'python': {'path': 'fake', 'attr': 'FakeClass'}},
}
component = discover_engine.Component(owner='test', manifest=manifest, rel_path='fake.yaml')
model_mgr.requester_components = [component]
info = model_mgr.get_available_requester_info_by_name('named-req')
assert info is not None
assert info['name'] == 'named-req'
info = model_mgr.get_available_requester_info_by_name('unknown-req')
assert info is None
def test_model_manager_get_available_requester_manifest_by_name(fake_requester_registry):
"""Test ModelManager.get_available_requester_manifest_by_name returns component."""
model_mgr = fake_requester_registry
from langbot.pkg.discover import engine as discover_engine
manifest = {
'apiVersion': 'v1',
'kind': 'LLMAPIRequester',
'metadata': {'name': 'manifest-req', 'label': {'en_US': 'Manifest'}, 'description': {'en_US': 'Manifest'}},
'spec': {'support_type': ['chat']},
'execution': {'python': {'path': 'fake', 'attr': 'FakeClass'}},
}
component = discover_engine.Component(owner='test', manifest=manifest, rel_path='fake.yaml')
model_mgr.requester_components = [component]
comp = model_mgr.get_available_requester_manifest_by_name('manifest-req')
assert comp is not None
assert comp.metadata.name == 'manifest-req'
comp = model_mgr.get_available_requester_manifest_by_name('unknown-req')
assert comp is None
# ============================================================================
# Temporary Runtime Model Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_registry):
"""Test ModelManager.init_temporary_runtime_llm_model creates model correctly."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
model_info = {
'uuid': 'temp-model-uuid',
'name': 'TempModel',
'provider': {
'uuid': 'temp-provider-uuid',
'name': 'Temp Provider',
'requester': 'fake-requester',
'base_url': 'https://temp.example.com',
'api_keys': ['temp-key'],
},
'abilities': ['func_call'],
'extra_args': {'temperature': 0.5},
}
runtime_model = await model_mgr.init_temporary_runtime_llm_model(model_info)
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
assert runtime_model.model_entity.name == 'TempModel'
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
@pytest.mark.asyncio
async def test_model_manager_init_temporary_runtime_embedding_model(fake_requester_registry):
"""Test ModelManager.init_temporary_runtime_embedding_model creates model correctly."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
model_info = {
'uuid': 'temp-embedding-uuid',
'name': 'TempEmbedding',
'provider': {
'uuid': 'temp-provider-uuid',
'name': 'Temp Provider',
'requester': 'fake-requester',
'base_url': 'https://temp.example.com',
'api_keys': [],
},
'extra_args': {'dimensions': 512},
}
runtime_model = await model_mgr.init_temporary_runtime_embedding_model(model_info)
assert runtime_model.model_entity.uuid == 'temp-embedding-uuid'
assert runtime_model.model_entity.name == 'TempEmbedding'
@pytest.mark.asyncio
async def test_model_manager_init_temporary_runtime_rerank_model(fake_requester_registry):
"""Test ModelManager.init_temporary_runtime_rerank_model creates model correctly."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
model_info = {
'uuid': 'temp-rerank-uuid',
'name': 'TempRerank',
'provider': {
'uuid': 'temp-provider-uuid',
'name': 'Temp Provider',
'requester': 'fake-requester',
'base_url': 'https://temp.example.com',
'api_keys': [],
},
'extra_args': {},
}
runtime_model = await model_mgr.init_temporary_runtime_rerank_model(model_info)
assert runtime_model.model_entity.uuid == 'temp-rerank-uuid'
assert runtime_model.model_entity.name == 'TempRerank'
# ============================================================================
# Provider Reload Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_reload_provider(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.reload_provider reloads provider and updates model refs."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
# For initial load - return all providers
rows = [_make_row_mock(p) for p in fake_persistence_data['providers']]
return _make_mock_result(rows)
elif 'llm_models' in query_str:
rows = [_make_row_mock(m) for m in fake_persistence_data['llm_models']]
return _make_mock_result(rows)
elif 'embedding_models' in query_str:
rows = [_make_row_mock(m) for m in fake_persistence_data['embedding_models']]
return _make_mock_result(rows)
elif 'rerank_models' in query_str:
rows = [_make_row_mock(m) for m in fake_persistence_data['rerank_models']]
return _make_mock_result(rows)
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
original_provider = model_mgr.provider_dict[fake_persistence_data['provider_uuid']]
original_base_url = original_provider.provider_entity.base_url
# Setup for reload - return updated provider
async def reload_execute(query):
updated_provider = persistence_model.ModelProvider(
uuid=fake_persistence_data['provider_uuid'],
name='Updated Provider',
requester='fake-requester',
base_url='https://updated.example.com',
api_keys=['updated-key'],
)
return _make_mock_result([_make_row_mock(updated_provider)], first_item=_make_row_mock(updated_provider))
model_mgr.ap.persistence_mgr.execute_async = reload_execute
await model_mgr.reload_provider(fake_persistence_data['provider_uuid'])
updated_provider = model_mgr.provider_dict[fake_persistence_data['provider_uuid']]
assert updated_provider.provider_entity.base_url == 'https://updated.example.com'
assert updated_provider.provider_entity.base_url != original_base_url
@pytest.mark.asyncio
async def test_model_manager_reload_provider_not_found(fake_requester_registry):
"""Test ModelManager.reload_provider raises ProviderNotFoundError."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
async def fake_execute(query):
return _make_mock_result([], first_item=None)
model_mgr.ap.persistence_mgr.execute_async = fake_execute
with pytest.raises(provider_errors.ProviderNotFoundError) as exc_info:
await model_mgr.reload_provider('unknown-provider-uuid')
assert exc_info.value.provider_name == 'unknown-provider-uuid'
# ============================================================================
# Model Load with Provider Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
model_mgr = fake_requester_registry
model_entity = fake_persistence_data['llm_models'][0]
runtime_model = await model_mgr.load_llm_model_with_provider(model_entity, runtime_provider)
assert runtime_model.model_entity.uuid == model_entity.uuid
assert runtime_model.provider is runtime_provider
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
model_mgr = fake_requester_registry
model_entity = fake_persistence_data['llm_models'][0]
row_mock = _make_row_mock(model_entity)
runtime_model = await model_mgr.load_llm_model_with_provider(row_mock, runtime_provider)
assert runtime_model.model_entity.uuid == model_entity.uuid
@pytest.mark.asyncio
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
model_mgr = fake_requester_registry
model_entity = fake_persistence_data['embedding_models'][0]
runtime_model = await model_mgr.load_embedding_model_with_provider(model_entity, runtime_provider)
assert runtime_model.model_entity.uuid == model_entity.uuid
assert runtime_model.provider is runtime_provider
@pytest.mark.asyncio
async def test_model_manager_load_rerank_model_with_provider(fake_requester_registry, fake_persistence_data):
"""Test ModelManager.load_rerank_model_with_provider creates RuntimeRerankModel."""
model_mgr = fake_requester_registry
await model_mgr.initialize()
provider_entity = fake_persistence_data['providers'][1]
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
requester_inst = model_mgr.requester_dict['another-fake-requester'](
ap=model_mgr.ap, config={'base_url': provider_entity.base_url}
)
await requester_inst.initialize()
provider = requester.RuntimeProvider(
provider_entity=provider_entity,
token_mgr=token_mgr,
requester=requester_inst,
)
model_entity = fake_persistence_data['rerank_models'][0]
runtime_model = await model_mgr.load_rerank_model_with_provider(model_entity, provider)
assert runtime_model.model_entity.uuid == model_entity.uuid
assert runtime_model.provider is provider
# ============================================================================
# Missing Provider Warning Tests
# ============================================================================
@pytest.mark.asyncio
async def test_model_manager_logs_warning_for_missing_provider(fake_requester_registry):
"""Test ModelManager logs warning when model's provider is missing."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
# Return empty providers
return _make_mock_result([])
elif 'llm_models' in query_str:
# Return model with missing provider
fake_model = persistence_model.LLMModel(
uuid='model-with-missing-provider',
name='MissingProviderModel',
provider_uuid='missing-provider-uuid',
abilities=[],
extra_args={},
)
return _make_mock_result([_make_row_mock(fake_model)])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
# Should have logged warning and skipped the model
assert len(model_mgr.llm_models) == 0
model_mgr.ap.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_model_manager_handles_requester_not_found_gracefully(fake_requester_registry):
"""Test ModelManager handles RequesterNotFoundError during provider load."""
model_mgr = fake_requester_registry
async def fake_execute(query):
query_str = str(query)
if 'model_providers' in query_str:
# Return provider with unknown requester
fake_provider = persistence_model.ModelProvider(
uuid='provider-with-unknown-requester',
name='Unknown Requester Provider',
requester='unknown-requester-name',
base_url='https://unknown.com',
api_keys=[],
)
return _make_mock_result([_make_row_mock(fake_provider)])
elif 'llm_models' in query_str:
fake_model = persistence_model.LLMModel(
uuid='model-uuid',
name='Model',
provider_uuid='provider-with-unknown-requester',
abilities=[],
extra_args={},
)
return _make_mock_result([_make_row_mock(fake_model)])
return _make_mock_result([])
model_mgr.ap.persistence_mgr.execute_async = fake_execute
await model_mgr.initialize()
# Provider should be skipped
assert len(model_mgr.provider_dict) == 0
assert len(model_mgr.llm_models) == 0
model_mgr.ap.logger.warning.assert_called()
# ============================================================================
# Error Classes Tests
# ============================================================================
def test_requester_not_found_error_str():
"""Test RequesterNotFoundError string representation."""
error = provider_errors.RequesterNotFoundError('test-requester')
assert str(error) == 'Requester test-requester not found'
assert error.requester_name == 'test-requester'
def test_provider_not_found_error_str():
"""Test ProviderNotFoundError string representation."""
error = provider_errors.ProviderNotFoundError('test-provider')
assert str(error) == 'Provider test-provider not found'
assert error.provider_name == 'test-provider'
@@ -0,0 +1,633 @@
"""
Unit tests for ProviderAPIRequester base class and runtime entities in provider/modelmgr.
Tests requester initialization, configuration handling, token management,
and runtime model/provider behavior without calling real LLM APIs.
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.provider.modelmgr import requester
from langbot.pkg.provider.modelmgr import token
from langbot.pkg.entity.persistence import model as persistence_model
from langbot.pkg.provider.modelmgr.errors import RequesterError
# ============================================================================
# ProviderAPIRequester Base Class Tests
# ============================================================================
class TestableRequester(requester.ProviderAPIRequester):
"""Testable requester subclass for testing base class behavior."""
name = 'testable-requester'
default_config = {
'base_url': 'https://default.example.com',
'timeout': 60,
'max_retries': 3,
}
async def invoke_llm(
self,
query,
model: requester.RuntimeLLMModel,
messages: list,
funcs=None,
extra_args={},
remove_think=False,
):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Testable response')],
)
def test_requester_base_class_is_abstract():
"""Test ProviderAPIRequester cannot be instantiated directly."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
# ProviderAPIRequester has abstract methods, but ABCMeta allows instantiation
# if you don't call the abstract methods. Test that it has abstract methods.
assert hasattr(requester.ProviderAPIRequester, 'invoke_llm')
# Check that invoke_llm is abstract
assert hasattr(requester.ProviderAPIRequester.invoke_llm, '__isabstractmethod__')
def test_requester_default_config_merged():
"""Test requester merges default config with provided config."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {'base_url': 'https://custom.example.com', 'custom_key': 'custom_value'})
assert inst.requester_cfg['base_url'] == 'https://custom.example.com'
assert inst.requester_cfg['timeout'] == 60 # from default
assert inst.requester_cfg['max_retries'] == 3 # from default
assert inst.requester_cfg['custom_key'] == 'custom_value' # custom added
def test_requester_default_config_not_modified():
"""Test that default_config dict is not modified when merging."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {'base_url': 'https://override.example.com'})
assert TestableRequester.default_config['base_url'] == 'https://default.example.com'
assert inst.requester_cfg['base_url'] == 'https://override.example.com'
def test_requester_empty_config_uses_defaults():
"""Test requester uses defaults when empty config provided."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {})
assert inst.requester_cfg == inst.default_config
@pytest.mark.asyncio
async def test_requester_initialize_is_callable():
"""Test requester initialize method is callable (default is pass)."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {})
await inst.initialize()
# No exception should occur
@pytest.mark.asyncio
async def test_requester_scan_models_not_implemented():
"""Test scan_models raises NotImplementedError by default."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {})
await inst.initialize()
with pytest.raises(NotImplementedError) as exc_info:
await inst.scan_models()
assert 'does not support model scanning' in str(exc_info.value)
@pytest.mark.asyncio
async def test_requester_invoke_rerank_not_implemented():
"""Test invoke_rerank raises NotImplementedError by default."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {})
await inst.initialize()
# Create fake model
fake_provider_entity = persistence_model.ModelProvider(
uuid='provider-uuid',
name='Provider',
requester='test',
base_url='https://test.com',
api_keys=[],
)
fake_token_mgr = token.TokenManager(name='test', tokens=[])
fake_requester = inst
fake_provider = requester.RuntimeProvider(
provider_entity=fake_provider_entity,
token_mgr=fake_token_mgr,
requester=fake_requester,
)
fake_model_entity = persistence_model.RerankModel(
uuid='model-uuid',
name='Model',
provider_uuid='provider-uuid',
extra_args={},
)
fake_model = requester.RuntimeRerankModel(
model_entity=fake_model_entity,
provider=fake_provider,
)
with pytest.raises(NotImplementedError) as exc_info:
await inst.invoke_rerank(fake_model, 'query', ['doc1', 'doc2'])
assert 'does not support rerank' in str(exc_info.value)
# ============================================================================
# TokenManager Tests
# ============================================================================
def test_token_manager_initial_state():
"""Test TokenManager initial state."""
mgr = token.TokenManager(name='test-manager', tokens=['key1', 'key2', 'key3'])
assert mgr.name == 'test-manager'
assert mgr.tokens == ['key1', 'key2', 'key3']
assert mgr.using_token_index == 0
def test_token_manager_get_token():
"""Test TokenManager.get_token returns current token."""
mgr = token.TokenManager(name='test', tokens=['key1', 'key2'])
assert mgr.get_token() == 'key1'
def test_token_manager_get_token_empty():
"""Test TokenManager.get_token returns empty string when no tokens."""
mgr = token.TokenManager(name='test', tokens=[])
assert mgr.get_token() == ''
def test_token_manager_next_token_cycles():
"""Test TokenManager.next_token cycles through tokens."""
mgr = token.TokenManager(name='test', tokens=['key1', 'key2', 'key3'])
assert mgr.get_token() == 'key1'
mgr.next_token()
assert mgr.get_token() == 'key2'
mgr.next_token()
assert mgr.get_token() == 'key3'
# Should cycle back to first
mgr.next_token()
assert mgr.get_token() == 'key1'
def test_token_manager_next_token_single():
"""Test TokenManager.next_token with single token."""
mgr = token.TokenManager(name='test', tokens=['single-key'])
mgr.next_token()
assert mgr.get_token() == 'single-key'
mgr.next_token()
assert mgr.get_token() == 'single-key'
def test_token_manager_next_token_empty():
"""Test TokenManager.next_token with empty tokens doesn't error."""
mgr = token.TokenManager(name='test', tokens=[])
assert mgr.next_token() is None
assert mgr.get_token() == ''
# ============================================================================
# RuntimeProvider Tests
# ============================================================================
def test_runtime_provider_initialization(runtime_provider, fake_persistence_data):
"""Test RuntimeProvider initialization."""
provider = runtime_provider
provider_entity = fake_persistence_data['providers'][0]
assert provider.provider_entity.uuid == provider_entity.uuid
assert provider.provider_entity.name == provider_entity.name
assert provider.token_mgr.name == provider_entity.uuid
assert provider.token_mgr.tokens == provider_entity.api_keys
assert isinstance(provider.requester, requester.ProviderAPIRequester)
def test_runtime_provider_has_invoke_methods(runtime_provider):
"""Test RuntimeProvider has invoke methods that delegate to requester."""
provider = runtime_provider
assert hasattr(provider, 'invoke_llm')
assert hasattr(provider, 'invoke_llm_stream')
assert hasattr(provider, 'invoke_embedding')
assert hasattr(provider, 'invoke_rerank')
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_llm_model):
"""Test RuntimeProvider.invoke_llm delegates to requester."""
provider = runtime_provider
# Track that requester was called
provider.requester._invoke_count = 0
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# Create minimal query for testing (bypass validation)
query = pipeline_query.Query.model_construct(
query_id='test-query',
launcher_type='person',
launcher_id=12345,
sender_id=12345,
message_chain=None,
message_event=None,
adapter=None,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
result = await provider.invoke_llm(query, runtime_llm_model, messages)
assert provider.requester._invoke_count == 1
assert provider.requester._last_messages == messages
assert provider.requester._last_model == runtime_llm_model
assert result.role == 'assistant'
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider, runtime_llm_model):
"""Test RuntimeProvider.invoke_llm_stream yields chunks from requester."""
provider = runtime_provider
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
query = pipeline_query.Query.model_construct(
query_id='test-stream',
launcher_type='person',
launcher_id=12345,
sender_id=12345,
message_chain=None,
message_event=None,
adapter=None,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
chunks = []
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
chunks.append(chunk)
assert len(chunks) == 1
assert chunks[0].role == 'assistant'
@pytest.mark.asyncio
async def test_runtime_provider_invoke_embedding_returns_vectors(runtime_provider, runtime_embedding_model):
"""Test RuntimeProvider.invoke_embedding returns embedding vectors."""
provider = runtime_provider
result = await provider.invoke_embedding(runtime_embedding_model, ['text1', 'text2'])
assert len(result) == 2
assert result[0] == [0.1, 0.2, 0.3]
@pytest.mark.asyncio
async def test_runtime_provider_invoke_rerank_returns_scores(runtime_provider, runtime_rerank_model):
"""Test RuntimeProvider.invoke_rerank returns relevance scores."""
# Need to use the correct provider for rerank model
provider = runtime_rerank_model.provider
result = await provider.invoke_rerank(runtime_rerank_model, 'query', ['doc1', 'doc2', 'doc3'])
assert len(result) == 3
assert result[0]['index'] == 0
assert result[0]['relevance_score'] == 0.9
# ============================================================================
# RuntimeLLMModel Tests
# ============================================================================
def test_runtime_llm_model_initialization(runtime_llm_model, fake_persistence_data):
"""Test RuntimeLLMModel initialization."""
model = runtime_llm_model
model_entity = fake_persistence_data['llm_models'][0]
assert model.model_entity.uuid == model_entity.uuid
assert model.model_entity.name == model_entity.name
assert model.model_entity.abilities == model_entity.abilities
assert model.model_entity.extra_args == model_entity.extra_args
assert model.provider is not None
def test_runtime_llm_model_provider_ref(runtime_llm_model):
"""Test RuntimeLLMModel has correct provider reference."""
model = runtime_llm_model
assert model.provider.provider_entity is not None
assert model.provider.token_mgr is not None
assert model.provider.requester is not None
# ============================================================================
# RuntimeEmbeddingModel Tests
# ============================================================================
def test_runtime_embedding_model_initialization(runtime_embedding_model, fake_persistence_data):
"""Test RuntimeEmbeddingModel initialization."""
model = runtime_embedding_model
model_entity = fake_persistence_data['embedding_models'][0]
assert model.model_entity.uuid == model_entity.uuid
assert model.model_entity.name == model_entity.name
assert model.model_entity.extra_args == model_entity.extra_args
assert model.provider is not None
# ============================================================================
# RuntimeRerankModel Tests
# ============================================================================
def test_runtime_rerank_model_initialization(runtime_rerank_model, fake_persistence_data):
"""Test RuntimeRerankModel initialization."""
model = runtime_rerank_model
model_entity = fake_persistence_data['rerank_models'][0]
assert model.model_entity.uuid == model_entity.uuid
assert model.model_entity.name == model_entity.name
assert model.model_entity.extra_args == model_entity.extra_args
assert model.provider is not None
# ============================================================================
# RequesterError Tests
# ============================================================================
def test_requester_error_message_format():
"""Test RequesterError message format."""
error = RequesterError('API returned 500')
assert '模型请求失败' in str(error)
assert 'API returned 500' in str(error)
def test_requester_error_is_exception():
"""Test RequesterError is Exception subclass."""
error = RequesterError('test')
assert isinstance(error, Exception)
# ============================================================================
# ProviderAPIRequester Config Validation Tests
# ============================================================================
def test_requester_with_missing_base_url():
"""Test requester handles missing base_url in config."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
# If base_url is in default_config, it will be used
inst = TestableRequester(mock_app, {'timeout': 30})
assert inst.requester_cfg['base_url'] == 'https://default.example.com'
def test_requester_with_none_values():
"""Test requester handles None values in config."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = TestableRequester(mock_app, {'timeout': None, 'base_url': 'https://test.com'})
# None values are kept in the merged config
assert inst.requester_cfg['timeout'] is None
class RequesterWithNoDefaults(requester.ProviderAPIRequester):
"""Requester with empty defaults for testing."""
name = 'no-defaults-requester'
default_config = {}
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
pass
def test_requester_empty_defaults_with_empty_config():
"""Test requester with empty defaults and empty config."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = RequesterWithNoDefaults(mock_app, {})
assert inst.requester_cfg == {}
def test_requester_empty_defaults_with_values():
"""Test requester with empty defaults receives config values."""
mock_app = SimpleNamespace()
mock_app.logger = Mock()
inst = RequesterWithNoDefaults(mock_app, {'base_url': 'https://custom.com', 'api_key': 'key'})
assert inst.requester_cfg['base_url'] == 'https://custom.com'
assert inst.requester_cfg['api_key'] == 'key'
# ============================================================================
# RuntimeProvider Error Handling Tests
# ============================================================================
class ErrorThrowingRequester(requester.ProviderAPIRequester):
"""Requester that throws errors for testing."""
name = 'error-requester'
default_config = {}
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
raise RequesterError('Simulated API error')
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmgr):
"""Test RuntimeProvider.invoke_llm propagates requester errors."""
mock_app = mock_app_for_modelmgr
# Add monitoring_service for error handling path
mock_app.monitoring_service = AsyncMock()
requester_inst = ErrorThrowingRequester(mock_app, {})
await requester_inst.initialize()
provider_entity = persistence_model.ModelProvider(
uuid='error-provider',
name='Error Provider',
requester='error-requester',
base_url='https://error.com',
api_keys=['error-key'],
)
token_mgr = token.TokenManager(name='error-provider', tokens=['error-key'])
provider = requester.RuntimeProvider(
provider_entity=provider_entity,
token_mgr=token_mgr,
requester=requester_inst,
)
model_entity = persistence_model.LLMModel(
uuid='error-model',
name='Error Model',
provider_uuid='error-provider',
abilities=[],
extra_args={},
)
model = requester.RuntimeLLMModel(model_entity=model_entity, provider=provider)
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
query = pipeline_query.Query.model_construct(
query_id='error-query',
launcher_type='person',
launcher_id=12345,
sender_id=12345,
message_chain=None,
message_event=None,
adapter=None,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
with pytest.raises(RequesterError):
await provider.invoke_llm(query, model, messages)
# ============================================================================
# LLMModelInfo Tests (from entities.py)
# ============================================================================
def test_llm_model_info_basic():
"""Test LLMModelInfo basic structure."""
from langbot.pkg.provider.modelmgr.entities import LLMModelInfo
mock_app = SimpleNamespace()
mock_app.logger = Mock()
fake_requester = TestableRequester(mock_app, {})
fake_token_mgr = token.TokenManager(name='test', tokens=['key'])
info = LLMModelInfo(
name='test-model',
model_name='gpt-4',
token_mgr=fake_token_mgr,
requester=fake_requester,
tool_call_supported=True,
vision_supported=False,
)
assert info.name == 'test-model'
assert info.model_name == 'gpt-4'
assert info.tool_call_supported == True
assert info.vision_supported == False
def test_llm_model_info_optional_fields():
"""Test LLMModelInfo optional fields default values."""
from langbot.pkg.provider.modelmgr.entities import LLMModelInfo
mock_app = SimpleNamespace()
mock_app.logger = Mock()
fake_requester = TestableRequester(mock_app, {})
fake_token_mgr = token.TokenManager(name='test', tokens=['key'])
info = LLMModelInfo(
name='minimal-model',
token_mgr=fake_token_mgr,
requester=fake_requester,
)
assert info.model_name is None
assert info.tool_call_supported == False # default
assert info.vision_supported == False # default
@@ -0,0 +1,321 @@
"""Unit tests for SessionManager.
Tests cover:
- Session creation and retrieval
- Conversation creation with prompts
- Session concurrency semaphore
"""
from __future__ import annotations
import pytest
import asyncio
from unittest.mock import Mock
from importlib import import_module
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
def get_session_module():
"""Lazy import to avoid circular import issues."""
return import_module('langbot.pkg.provider.session.sessionmgr')
class TestSessionManagerInit:
"""Tests for SessionManager initialization."""
def test_init_stores_app_reference(self):
"""Test that __init__ stores the Application reference."""
sessionmgr = get_session_module()
mock_app = Mock()
manager = sessionmgr.SessionManager(mock_app)
assert manager.ap is mock_app
def test_init_empty_session_list(self):
"""Test that session_list starts empty."""
sessionmgr = get_session_module()
mock_app = Mock()
manager = sessionmgr.SessionManager(mock_app)
assert manager.session_list == []
@pytest.mark.asyncio
async def test_initialize_empty(self):
"""Test that initialize does nothing (current implementation)."""
sessionmgr = get_session_module()
mock_app = Mock()
manager = sessionmgr.SessionManager(mock_app)
await manager.initialize()
# Should not raise or change state
assert manager.session_list == []
class TestSessionManagerGetSession:
"""Tests for get_session method."""
@pytest.fixture
def mock_app_with_config(self):
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
return mock_app
@pytest.fixture
def sample_query(self):
"""Create sample query for testing."""
query = Mock(spec=pipeline_query.Query)
query.launcher_type = provider_session.LauncherTypes.PERSON
query.launcher_id = '12345'
query.sender_id = '12345'
return query
@pytest.mark.asyncio
async def test_creates_new_session_when_not_found(self, mock_app_with_config, sample_query):
"""Test that get_session creates new session when not found."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
session = await manager.get_session(sample_query)
assert session is not None
assert session.launcher_type == sample_query.launcher_type
assert session.launcher_id == sample_query.launcher_id
assert session.sender_id == sample_query.sender_id
assert len(manager.session_list) == 1
@pytest.mark.asyncio
async def test_returns_existing_session_when_found(self, mock_app_with_config, sample_query):
"""Test that get_session returns existing session when found."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
# First call creates session
session1 = await manager.get_session(sample_query)
# Second call should return same session
session2 = await manager.get_session(sample_query)
assert session1 is session2
assert len(manager.session_list) == 1
@pytest.mark.asyncio
async def test_session_has_semaphore(self, mock_app_with_config, sample_query):
"""Test that created session has semaphore for concurrency."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
session = await manager.get_session(sample_query)
assert hasattr(session, '_semaphore')
assert session._semaphore is not None
assert isinstance(session._semaphore, asyncio.Semaphore)
@pytest.mark.asyncio
async def test_different_launchers_have_different_sessions(self, mock_app_with_config):
"""Test that different launcher_id creates different sessions."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
query1 = Mock(spec=pipeline_query.Query)
query1.launcher_type = provider_session.LauncherTypes.PERSON
query1.launcher_id = 'user1'
query1.sender_id = 'user1'
query2 = Mock(spec=pipeline_query.Query)
query2.launcher_type = provider_session.LauncherTypes.PERSON
query2.launcher_id = 'user2'
query2.sender_id = 'user2'
session1 = await manager.get_session(query1)
session2 = await manager.get_session(query2)
assert session1 is not session2
assert len(manager.session_list) == 2
@pytest.mark.asyncio
async def test_different_launcher_types_have_different_sessions(self, mock_app_with_config):
"""Test that different launcher_type creates different sessions."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
query1 = Mock(spec=pipeline_query.Query)
query1.launcher_type = provider_session.LauncherTypes.PERSON
query1.launcher_id = 'same_id'
query1.sender_id = 'same_id'
query2 = Mock(spec=pipeline_query.Query)
query2.launcher_type = provider_session.LauncherTypes.GROUP
query2.launcher_id = 'same_id'
query2.sender_id = 'same_id'
session1 = await manager.get_session(query1)
session2 = await manager.get_session(query2)
assert session1 is not session2
assert len(manager.session_list) == 2
class TestSessionManagerGetConversation:
"""Tests for get_conversation method."""
@pytest.fixture
def mock_app_with_config(self):
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
return mock_app
@pytest.fixture
def sample_session(self):
"""Create sample session for testing."""
session = Mock(spec=provider_session.Session)
session.launcher_type = provider_session.LauncherTypes.PERSON
session.launcher_id = '12345'
session.sender_id = '12345'
session.conversations = []
session.using_conversation = None
return session
@pytest.fixture
def sample_query(self):
"""Create sample query for testing."""
query = Mock(spec=pipeline_query.Query)
query.launcher_type = provider_session.LauncherTypes.PERSON
query.launcher_id = '12345'
query.sender_id = '12345'
return query
@pytest.mark.asyncio
async def test_creates_conversation_with_prompt(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that get_conversation creates conversation with prompt."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
assert conversation is not None
assert conversation.pipeline_uuid == pipeline_uuid
assert conversation.bot_uuid == bot_uuid
assert conversation.prompt is not None
assert len(sample_session.conversations) == 1
@pytest.mark.asyncio
async def test_uses_existing_conversation_when_pipeline_matches(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that get_conversation uses existing conversation when pipeline matches."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
# First call creates conversation
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
# Second call with same pipeline should return same conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
assert conv1 is conv2
assert len(sample_session.conversations) == 1
@pytest.mark.asyncio
async def test_creates_new_conversation_when_pipeline_changes(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that get_conversation creates new conversation when pipeline changes."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
# First call with pipeline1
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
)
# Second call with different pipeline should create new conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
)
assert conv1 is not conv2
assert len(sample_session.conversations) == 2
assert sample_session.using_conversation is conv2
@pytest.mark.asyncio
async def test_conversation_has_empty_messages(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that created conversation has empty messages list."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
)
assert conversation.messages == []
@pytest.mark.asyncio
async def test_prompt_messages_from_config(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that prompt messages are created from prompt_config."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'System message'},
{'role': 'user', 'content': 'User message'}
]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
)
assert conversation.prompt.name == 'default'
assert len(conversation.prompt.messages) == 2
@@ -0,0 +1,336 @@
"""Unit tests for ToolManager.
Tests cover:
- Tool schema generation for OpenAI and Anthropic
- Tool execution dispatch
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
def get_toolmgr_module():
"""Lazy import to avoid circular import issues."""
return import_module('langbot.pkg.provider.tools.toolmgr')
class TestToolManagerInit:
"""Tests for ToolManager initialization."""
def test_init_stores_app_reference(self):
"""Test that __init__ stores the Application reference."""
toolmgr = get_toolmgr_module()
mock_app = Mock()
manager = toolmgr.ToolManager(mock_app)
assert manager.ap is mock_app
def test_init_no_tool_loaders(self):
"""Test that tool loaders are not initialized before initialize()."""
toolmgr = get_toolmgr_module()
mock_app = Mock()
manager = toolmgr.ToolManager(mock_app)
assert hasattr(manager, 'plugin_tool_loader') is False or manager.plugin_tool_loader is None
class TestToolManagerSchemaGeneration:
"""Tests for tool schema generation methods."""
@pytest.fixture
def mock_app(self):
"""Create mock app."""
mock_app = Mock()
mock_app.logger = Mock()
return mock_app
@pytest.fixture
def sample_tools(self):
"""Create sample LLMTool list for testing."""
def dummy_weather_func(**kwargs):
return "weather result"
def dummy_calc_func(**kwargs):
return "calc result"
tools = [
resource_tool.LLMTool(
name='get_weather',
human_desc='Get current weather for a location',
description='Get current weather for a location',
parameters={
'type': 'object',
'properties': {
'location': {
'type': 'string',
'description': 'City name'
}
},
'required': ['location']
},
func=dummy_weather_func
),
resource_tool.LLMTool(
name='calculate',
human_desc='Perform a calculation',
description='Perform a calculation',
parameters={
'type': 'object',
'properties': {
'expression': {
'type': 'string',
'description': 'Math expression'
}
},
'required': ['expression']
},
func=dummy_calc_func
),
]
return tools
@pytest.mark.asyncio
async def test_generate_tools_for_openai(self, mock_app, sample_tools):
"""Test that generate_tools_for_openai produces correct schema."""
toolmgr = get_toolmgr_module()
manager = toolmgr.ToolManager(mock_app)
result = await manager.generate_tools_for_openai(sample_tools)
assert len(result) == 2
# Verify first tool schema
tool1 = result[0]
assert tool1['type'] == 'function'
assert tool1['function']['name'] == 'get_weather'
assert tool1['function']['description'] == 'Get current weather for a location'
assert 'parameters' in tool1['function']
assert tool1['function']['parameters']['type'] == 'object'
# Verify second tool schema
tool2 = result[1]
assert tool2['type'] == 'function'
assert tool2['function']['name'] == 'calculate'
@pytest.mark.asyncio
async def test_generate_tools_for_anthropic(self, mock_app, sample_tools):
"""Test that generate_tools_for_anthropic produces correct schema."""
toolmgr = get_toolmgr_module()
manager = toolmgr.ToolManager(mock_app)
result = await manager.generate_tools_for_anthropic(sample_tools)
assert len(result) == 2
# Verify first tool schema (Anthropic format)
tool1 = result[0]
assert tool1['name'] == 'get_weather'
assert tool1['description'] == 'Get current weather for a location'
assert 'input_schema' in tool1
assert tool1['input_schema']['type'] == 'object'
# Verify second tool schema
tool2 = result[1]
assert tool2['name'] == 'calculate'
assert 'input_schema' in tool2
@pytest.mark.asyncio
async def test_generate_tools_empty_list(self, mock_app):
"""Test that generating tools from empty list returns empty list."""
toolmgr = get_toolmgr_module()
manager = toolmgr.ToolManager(mock_app)
openai_result = await manager.generate_tools_for_openai([])
assert openai_result == []
anthropic_result = await manager.generate_tools_for_anthropic([])
assert anthropic_result == []
@pytest.mark.asyncio
async def test_openai_schema_fields_complete(self, mock_app, sample_tools):
"""Test that OpenAI schema includes all required fields."""
toolmgr = get_toolmgr_module()
manager = toolmgr.ToolManager(mock_app)
result = await manager.generate_tools_for_openai(sample_tools)
for tool_schema in result:
assert 'type' in tool_schema
assert tool_schema['type'] == 'function'
assert 'function' in tool_schema
func = tool_schema['function']
assert 'name' in func
assert 'description' in func
assert 'parameters' in func
@pytest.mark.asyncio
async def test_anthropic_schema_fields_complete(self, mock_app, sample_tools):
"""Test that Anthropic schema includes all required fields."""
toolmgr = get_toolmgr_module()
manager = toolmgr.ToolManager(mock_app)
result = await manager.generate_tools_for_anthropic(sample_tools)
for tool_schema in result:
assert 'name' in tool_schema
assert 'description' in tool_schema
assert 'input_schema' in tool_schema
class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method."""
@pytest.fixture
def mock_app_with_loaders(self):
"""Create mock app with mock tool loaders."""
mock_app = Mock()
mock_app.logger = Mock()
# Create mock plugin loader
mock_plugin_loader = Mock()
mock_plugin_loader.has_tool = AsyncMock(return_value=False)
mock_plugin_loader.invoke_tool = AsyncMock(return_value='plugin_result')
mock_plugin_loader.initialize = AsyncMock()
mock_plugin_loader.shutdown = AsyncMock()
# Create mock MCP loader
mock_mcp_loader = Mock()
mock_mcp_loader.has_tool = AsyncMock(return_value=False)
mock_mcp_loader.invoke_tool = AsyncMock(return_value='mcp_result')
mock_mcp_loader.initialize = AsyncMock()
mock_mcp_loader.shutdown = AsyncMock()
return mock_app, mock_plugin_loader, mock_mcp_loader
@pytest.fixture
def sample_query(self):
"""Create sample query for testing."""
query = Mock(spec=pipeline_query.Query)
return query
@pytest.mark.asyncio
async def test_execute_calls_plugin_loader_when_has_tool(
self, mock_app_with_loaders, sample_query
):
"""Test that execute_func_call uses plugin loader when tool exists there."""
toolmgr = get_toolmgr_module()
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
mock_plugin_loader.has_tool = AsyncMock(return_value=True)
manager = toolmgr.ToolManager(mock_app)
manager.plugin_tool_loader = mock_plugin_loader
manager.mcp_tool_loader = mock_mcp_loader
result = await manager.execute_func_call(
'test_tool',
{'param': 'value'},
sample_query
)
assert result == 'plugin_result'
mock_plugin_loader.invoke_tool.assert_called_once_with(
'test_tool', {'param': 'value'}, sample_query
)
# MCP loader should not be called
mock_mcp_loader.invoke_tool.assert_not_called()
@pytest.mark.asyncio
async def test_execute_calls_mcp_loader_when_plugin_not_found(
self, mock_app_with_loaders, sample_query
):
"""Test that execute_func_call uses MCP loader when plugin doesn't have tool."""
toolmgr = get_toolmgr_module()
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
mock_plugin_loader.has_tool = AsyncMock(return_value=False)
mock_mcp_loader.has_tool = AsyncMock(return_value=True)
manager = toolmgr.ToolManager(mock_app)
manager.plugin_tool_loader = mock_plugin_loader
manager.mcp_tool_loader = mock_mcp_loader
result = await manager.execute_func_call(
'test_tool',
{'param': 'value'},
sample_query
)
assert result == 'mcp_result'
mock_mcp_loader.invoke_tool.assert_called_once_with(
'test_tool', {'param': 'value'}, sample_query
)
@pytest.mark.asyncio
async def test_execute_raises_when_tool_not_found(
self, mock_app_with_loaders, sample_query
):
"""Test that execute_func_call raises ValueError when tool not found."""
toolmgr = get_toolmgr_module()
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
mock_plugin_loader.has_tool = AsyncMock(return_value=False)
mock_mcp_loader.has_tool = AsyncMock(return_value=False)
manager = toolmgr.ToolManager(mock_app)
manager.plugin_tool_loader = mock_plugin_loader
manager.mcp_tool_loader = mock_mcp_loader
with pytest.raises(ValueError, match='未找到工具'):
await manager.execute_func_call(
'unknown_tool',
{},
sample_query
)
@pytest.mark.asyncio
async def test_plugin_loader_checked_first(
self, mock_app_with_loaders, sample_query
):
"""Test that plugin loader is checked before MCP loader."""
toolmgr = get_toolmgr_module()
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
# Both loaders have the tool, but plugin should be used
mock_plugin_loader.has_tool = AsyncMock(return_value=True)
mock_mcp_loader.has_tool = AsyncMock(return_value=True)
manager = toolmgr.ToolManager(mock_app)
manager.plugin_tool_loader = mock_plugin_loader
manager.mcp_tool_loader = mock_mcp_loader
await manager.execute_func_call('test_tool', {}, sample_query)
# Plugin loader should be invoked, MCP should not
mock_plugin_loader.invoke_tool.assert_called_once()
mock_mcp_loader.invoke_tool.assert_not_called()
class TestToolManagerShutdown:
"""Tests for shutdown method."""
@pytest.mark.asyncio
async def test_shutdown_calls_loader_shutdown(self):
"""Test that shutdown calls shutdown on both loaders."""
toolmgr = get_toolmgr_module()
mock_app = Mock()
mock_plugin_loader = Mock()
mock_plugin_loader.shutdown = AsyncMock()
mock_mcp_loader = Mock()
mock_mcp_loader.shutdown = AsyncMock()
manager = toolmgr.ToolManager(mock_app)
manager.plugin_tool_loader = mock_plugin_loader
manager.mcp_tool_loader = mock_mcp_loader
await manager.shutdown()
mock_plugin_loader.shutdown.assert_called_once()
mock_mcp_loader.shutdown.assert_called_once()