Compare commits

..

1 Commits

Author SHA1 Message Date
huanghuoguoguo
e5d71597f1 fix(survey): prevent option controls from submitting forms 2026-06-14 21:17:18 +08:00
107 changed files with 1757 additions and 3004 deletions

View File

@@ -1,46 +0,0 @@
name: Frontend Tests
on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
push:
branches:
- master
- develop
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
jobs:
playwright-smoke:
name: Playwright Smoke
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '25'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 8.9.2
- name: Install dependencies
working-directory: web
run: pnpm install --frozen-lockfile
- name: Install Playwright browsers
working-directory: web
run: pnpm exec playwright install --with-deps chromium
- name: Run Playwright smoke tests
working-directory: web
run: pnpm test:e2e

View File

@@ -29,7 +29,7 @@ jobs:
run: uv sync --dev run: uv sync --dev
- name: Run ruff check - name: Run ruff check
run: uv run ruff check src/langbot/ tests/ --output-format=concise run: uv run ruff check src
- name: Run ruff format - name: Run ruff format
run: uv run ruff format src --check run: uv run ruff format src --check

View File

@@ -84,67 +84,6 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
e2e:
name: E2E Startup Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run E2E startup tests
run: uv run pytest tests/e2e -q --tb=short
- name: E2E Test Summary
if: always()
run: |
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
box-integration:
name: Box Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Check Docker runtime
run: docker info
- name: Run Box integration tests
run: uv run pytest tests/integration_tests -q --tb=short
- name: Box Integration Test Summary
if: always()
run: |
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage: coverage:
name: Coverage Gate name: Coverage Gate
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -190,4 +129,4 @@ jobs:
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

View File

@@ -1,7 +1,6 @@
# LangBot Test Suite # LangBot Test Suite
This directory contains the LangBot backend test suite, including unit tests, This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
integration tests, startup E2E tests, and container-backed Box runtime tests.
## Quality Gate Layers ## Quality Gate Layers
@@ -11,15 +10,10 @@ LangBot uses a layered quality gate system for developers and CI:
|-------|---------|--------------|-------------| |-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit | | **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly | | **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI | | **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes | | **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default **Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
gates. They run in separate CI workflows. Frontend Playwright tests live under
`web/tests/e2e` and are documented in `web/README.md`.
### Developer Workflow ### Developer Workflow
@@ -34,9 +28,6 @@ make test-all-local
bash scripts/test-quick.sh # ~2 min bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min bash scripts/test-coverage.sh # ~8 min
uv run --python 3.12 pytest tests/e2e -q --tb=short
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
cd web && pnpm test:e2e
``` ```
### Coverage Baseline ### Coverage Baseline
@@ -79,12 +70,6 @@ tests/
│ └── persistence/ # Database/persistence tests │ └── persistence/ # Database/persistence tests
│ ├── __init__.py │ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests │ └── test_migrations.py # Alembic migration tests
├── e2e/ # Real LangBot startup E2E tests
│ ├── conftest.py
│ ├── test_startup.py
│ └── utils/
├── integration_tests/ # Container-backed integration tests
│ └── box/ # Box runtime and MCP process tests
├── smoke/ # Smoke tests (quick validation) ├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py │ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests ├── unit_tests/ # Unit tests
@@ -318,44 +303,6 @@ These tests:
- Test prevent_default, exception handling, and full message flow - Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys - Do not require real LLM provider keys
### Running backend E2E startup tests
Backend E2E tests start a real LangBot process with a generated minimal
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
They do not require provider keys or external services.
```bash
uv run --python 3.12 pytest tests/e2e -q --tb=short
```
These tests verify startup orchestration, migrations, API route registration,
and the minimal no-LLM startup path. The E2E process manager disables ambient
proxy variables for subprocess startup and uses direct localhost HTTP clients,
so local proxy settings should not affect the health checks.
### Running Box integration tests
Box integration tests exercise the real sandbox runtime path, including command
execution, session persistence, managed process WebSocket attachment, and
cleanup behavior.
```bash
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
```
These tests require a working Docker or Podman runtime. In CI, the dedicated
Box integration job checks Docker availability before running the tests.
### Running frontend E2E tests
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
and mock the LangBot backend and Space APIs, so no backend process is required.
```bash
cd web
pnpm test:e2e
```
### Known Issues ### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues: Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
@@ -373,9 +320,6 @@ Tests are automatically run on:
- Push to master/develop branches - Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility. The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
they exercise process/container behavior instead of pure Python compatibility.
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
## Adding New Tests ## Adding New Tests
@@ -462,4 +406,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
- [ ] Add E2E tests - [ ] Add E2E tests
- [ ] Add performance benchmarks - [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality - [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis - [ ] Add property-based testing with Hypothesis

View File

@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
base_url = f'http://127.0.0.1:{e2e_port}' base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client: with httpx.Client(base_url=base_url, timeout=10.0) as client:
yield client yield client
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir): def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file.""" """Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db' return e2e_tmpdir / 'data' / 'langbot.db'

View File

@@ -38,13 +38,12 @@ class TestStartupFlow:
# System info should contain version info # System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data'] assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, langbot_process, e2e_db_path): def test_database_initialized(self, e2e_db_path):
"""Verify SQLite database was created and initialized.""" """Verify SQLite database was created and initialized."""
assert e2e_db_path.exists() assert e2e_db_path.exists()
# Database should have some tables after migration # Database should have some tables after migration
import sqlite3 import sqlite3
conn = sqlite3.connect(str(e2e_db_path)) conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor() cursor = conn.cursor()
@@ -75,13 +74,10 @@ class TestStartupFlow:
def test_auth_endpoint(self, e2e_client, e2e_tmpdir): def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint.""" """Test auth endpoint."""
# First startup may allow initial setup # First startup may allow initial setup
response = e2e_client.post( response = e2e_client.post('/api/v1/user/auth', json={
'/api/v1/user/auth', 'username': 'admin',
json={ 'password': 'admin',
'user': 'admin', })
'password': 'admin',
},
)
# Response could be: # Response could be:
# - 200 if auth succeeds # - 200 if auth succeeds
@@ -98,10 +94,9 @@ class TestStartupStages:
# If API responds on e2e_port, config was loaded # If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200 assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, langbot_process, e2e_db_path): def test_migrations_applied(self, e2e_db_path):
"""Verify database migrations were applied.""" """Verify database migrations were applied."""
import sqlite3 import sqlite3
conn = sqlite3.connect(str(e2e_db_path)) conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor() cursor = conn.cursor()

View File

@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
for path in directories.values(): for path in directories.values():
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return directories return directories

View File

@@ -44,17 +44,6 @@ class LangBotProcess:
# Prepare environment # Prepare environment
env = os.environ.copy() env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src') env['PYTHONPATH'] = str(self.project_root / 'src')
for proxy_key in (
'HTTP_PROXY',
'HTTPS_PROXY',
'ALL_PROXY',
'http_proxy',
'https_proxy',
'all_proxy',
):
env.pop(proxy_key, None)
env['NO_PROXY'] = '127.0.0.1,localhost'
env['no_proxy'] = '127.0.0.1,localhost'
# Set API port via environment variable # Set API port via environment variable
env['API__PORT'] = str(self.port) env['API__PORT'] = str(self.port)
@@ -90,11 +79,9 @@ precision = 2
f.write(coveragerc_content) f.write(coveragerc_content)
cmd = [ cmd = [
'coverage', 'coverage', 'run',
'run',
'--rcfile=' + str(coveragerc_path), '--rcfile=' + str(coveragerc_path),
'-m', '-m', 'langbot',
'langbot',
] ]
else: else:
cmd = ['uv', 'run', 'python', '-m', 'langbot'] cmd = ['uv', 'run', 'python', '-m', 'langbot']
@@ -126,8 +113,6 @@ precision = 2
r = httpx.get( r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info', f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0, timeout=2.0,
follow_redirects=False,
trust_env=False,
) )
if r.status_code == 200: if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}') logger.info(f'LangBot started successfully on port {self.port}')
@@ -200,8 +185,6 @@ precision = 2
r = httpx.get( r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info', f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0, timeout=5.0,
follow_redirects=False,
trust_env=False,
) )
return r.status_code == 200 return r.status_code == 200
except Exception: except Exception:
@@ -218,4 +201,4 @@ def find_project_root() -> Path:
return parent return parent
# Fallback to LangBot-test-build directory # Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build') return Path('/home/glwuy/langbot-app/LangBot-test-build')

View File

@@ -58,45 +58,45 @@ from tests.factories.platform import (
__all__ = [ __all__ = [
# App # App
'FakeApp', "FakeApp",
'fake_app', "fake_app",
# Message chains # Message chains
'text_chain', "text_chain",
'group_text_chain', "group_text_chain",
'mention_chain', "mention_chain",
'image_chain', "image_chain",
# Message events # Message events
'friend_message_event', "friend_message_event",
'group_message_event', "group_message_event",
# Mock adapters # Mock adapters
'mock_adapter', "mock_adapter",
# Queries # Queries
'text_query', "text_query",
'group_text_query', "group_text_query",
'private_text_query', "private_text_query",
'command_query', "command_query",
'mention_query', "mention_query",
'empty_query', "empty_query",
'image_query', "image_query",
'file_query', "file_query",
'unsupported_query', "unsupported_query",
'voice_query', "voice_query",
'at_all_query', "at_all_query",
'query_with_session', "query_with_session",
'query_with_config', "query_with_config",
# Provider # Provider
'FakeProvider', "FakeProvider",
'fake_provider', "fake_provider",
'fake_provider_pong', "fake_provider_pong",
'fake_provider_timeout', "fake_provider_timeout",
'fake_provider_auth_error', "fake_provider_auth_error",
'fake_provider_rate_limit', "fake_provider_rate_limit",
'fake_provider_malformed', "fake_provider_malformed",
'fake_model', "fake_model",
# Platform # Platform
'FakePlatform', "FakePlatform",
'fake_platform', "fake_platform",
'fake_platform_with_streaming', "fake_platform_with_streaming",
'fake_platform_with_failure', "fake_platform_with_failure",
'mock_platform_adapter', "mock_platform_adapter",
] ]

View File

@@ -30,36 +30,32 @@ def _next_query_id() -> int:
# ============== Message Chain Factories ============== # ============== Message Chain Factories ==============
def text_chain(text: str = 'hello') -> platform_message.MessageChain: def text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a simple text message chain.""" """Create a simple text message chain."""
return platform_message.MessageChain( return platform_message.MessageChain([
[ platform_message.Plain(text=text),
platform_message.Plain(text=text), ])
]
)
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain: def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event).""" """Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text) return text_chain(text)
def mention_chain( def mention_chain(
text: str = 'hello', text: str = "hello",
target: typing.Union[int, str] = 12345, target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""Create a message chain with @mention.""" """Create a message chain with @mention."""
return platform_message.MessageChain( return platform_message.MessageChain([
[ platform_message.At(target=target),
platform_message.At(target=target), platform_message.Plain(text=f" {text}"),
platform_message.Plain(text=f' {text}'), ])
]
)
def image_chain( def image_chain(
text: str = '', text: str = "",
url: str = 'https://example.com/image.png', url: str = "https://example.com/image.png",
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""Create a message chain with an image.""" """Create a message chain with an image."""
components = [] components = []
@@ -70,15 +66,13 @@ def image_chain(
def command_chain( def command_chain(
command: str = 'help', command: str = "help",
prefix: str = '/', prefix: str = "/",
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""Create a command message chain.""" """Create a command message chain."""
return platform_message.MessageChain( return platform_message.MessageChain([
[ platform_message.Plain(text=f"{prefix}{command}"),
platform_message.Plain(text=f'{prefix}{command}'), ])
]
)
# ============== Message Event Factories ============== # ============== Message Event Factories ==============
@@ -87,7 +81,7 @@ def command_chain(
def friend_message_event( def friend_message_event(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
nickname: str = 'TestUser', nickname: str = "TestUser",
) -> platform_events.FriendMessage: ) -> platform_events.FriendMessage:
"""Create a friend (private) message event.""" """Create a friend (private) message event."""
sender = platform_entities.Friend( sender = platform_entities.Friend(
@@ -96,7 +90,7 @@ def friend_message_event(
remark=None, remark=None,
) )
return platform_events.FriendMessage( return platform_events.FriendMessage(
type='FriendMessage', type="FriendMessage",
sender=sender, sender=sender,
message_chain=message_chain, message_chain=message_chain,
time=1609459200, time=1609459200,
@@ -106,9 +100,9 @@ def friend_message_event(
def group_message_event( def group_message_event(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
sender_name: str = 'TestUser', sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
group_name: str = 'TestGroup', group_name: str = "TestGroup",
) -> platform_events.GroupMessage: ) -> platform_events.GroupMessage:
"""Create a group message event.""" """Create a group message event."""
group = platform_entities.Group( group = platform_entities.Group(
@@ -123,7 +117,7 @@ def group_message_event(
group=group, group=group,
) )
return platform_events.GroupMessage( return platform_events.GroupMessage(
type='GroupMessage', type="GroupMessage",
sender=sender, sender=sender,
message_chain=message_chain, message_chain=message_chain,
time=1609459200, time=1609459200,
@@ -158,36 +152,36 @@ def _base_query(
query_id = _next_query_id() query_id = _next_query_id()
base_data = { base_data = {
'query_id': query_id, "query_id": query_id,
'launcher_type': launcher_type, "launcher_type": launcher_type,
'launcher_id': launcher_id, "launcher_id": launcher_id,
'sender_id': sender_id, "sender_id": sender_id,
'message_chain': message_chain, "message_chain": message_chain,
'message_event': message_event, "message_event": message_event,
'adapter': adapter, "adapter": adapter,
'pipeline_uuid': 'test-pipeline-uuid', "pipeline_uuid": "test-pipeline-uuid",
'bot_uuid': 'test-bot-uuid', "bot_uuid": "test-bot-uuid",
'pipeline_config': { "pipeline_config": {
'ai': { "ai": {
'runner': {'runner': 'local-agent'}, "runner": {"runner": "local-agent"},
'local-agent': { "local-agent": {
'model': {'primary': 'test-model-uuid', 'fallbacks': []}, "model": {"primary": "test-model-uuid", "fallbacks": []},
'prompt': 'test-prompt', "prompt": "test-prompt",
}, },
}, },
'output': {'misc': {'at-sender': False, 'quote-origin': False}}, "output": {"misc": {"at-sender": False, "quote-origin": False}},
'trigger': {'misc': {'combine-quote-message': False}}, "trigger": {"misc": {"combine-quote-message": False}},
}, },
'session': None, "session": None,
'prompt': None, "prompt": None,
'messages': [], "messages": [],
'user_message': None, "user_message": None,
'use_funcs': [], "use_funcs": [],
'use_llm_model_uuid': None, "use_llm_model_uuid": None,
'variables': {}, "variables": {},
'resp_messages': [], "resp_messages": [],
'resp_message_chain': None, "resp_message_chain": None,
'current_stage_name': None, "current_stage_name": None,
} }
# Apply overrides # Apply overrides
@@ -198,7 +192,7 @@ def _base_query(
def text_query( def text_query(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -218,7 +212,7 @@ def text_query(
def private_text_query( def private_text_query(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -227,7 +221,7 @@ def private_text_query(
def group_text_query( def group_text_query(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
**overrides, **overrides,
@@ -248,8 +242,8 @@ def group_text_query(
def command_query( def command_query(
command: str = 'help', command: str = "help",
prefix: str = '/', prefix: str = "/",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -269,7 +263,7 @@ def command_query(
def mention_query( def mention_query(
text: str = 'hello', text: str = "hello",
target: typing.Union[int, str] = 12345, target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
@@ -307,8 +301,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
def image_query( def image_query(
text: str = '', text: str = "",
url: str = 'https://example.com/image.png', url: str = "https://example.com/image.png",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -328,9 +322,9 @@ def image_query(
def file_query( def file_query(
url: str = 'https://example.com/document.pdf', url: str = "https://example.com/document.pdf",
name: str = 'document.pdf', name: str = "document.pdf",
text: str = '', text: str = "",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -354,8 +348,8 @@ def file_query(
def unsupported_query( def unsupported_query(
unsupported_type: str = 'CustomComponent', unsupported_type: str = "CustomComponent",
text: str = '', text: str = "",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -364,7 +358,7 @@ def unsupported_query(
if text: if text:
components.append(platform_message.Plain(text=text)) components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types # Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}')) components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
chain = platform_message.MessageChain(components) chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id) event = friend_message_event(chain, sender_id)
adapter = mock_adapter() adapter = mock_adapter()
@@ -380,7 +374,7 @@ def unsupported_query(
def query_with_session( def query_with_session(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None, session: provider_session.Session = None,
**overrides, **overrides,
@@ -395,7 +389,7 @@ def query_with_session(
launcher_type=provider_session.LauncherTypes.PERSON, launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id, launcher_id=sender_id,
sender_id=sender_id, sender_id=sender_id,
use_prompt_name='default', use_prompt_name="default",
using_conversation=None, using_conversation=None,
conversations=[], conversations=[],
) )
@@ -404,7 +398,7 @@ def query_with_session(
def query_with_config( def query_with_config(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None, pipeline_config: dict = None,
**overrides, **overrides,
@@ -416,22 +410,22 @@ def query_with_config(
""" """
if pipeline_config is None: if pipeline_config is None:
pipeline_config = { pipeline_config = {
'ai': { "ai": {
'runner': {'runner': 'local-agent'}, "runner": {"runner": "local-agent"},
'local-agent': { "local-agent": {
'model': {'primary': 'test-model-uuid', 'fallbacks': []}, "model": {"primary": "test-model-uuid", "fallbacks": []},
'prompt': 'test-prompt', "prompt": "test-prompt",
}, },
}, },
'output': {'misc': {'at-sender': False, 'quote-origin': False}}, "output": {"misc": {"at-sender": False, "quote-origin": False}},
'trigger': {'misc': {'combine-quote-message': False}}, "trigger": {"misc": {"combine-quote-message": False}},
} }
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides) return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query( def voice_query(
url: str = 'https://example.com/audio.mp3', url: str = "https://example.com/audio.mp3",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
**overrides, **overrides,
) -> pipeline_query.Query: ) -> pipeline_query.Query:
@@ -454,7 +448,7 @@ def voice_query(
def at_all_query( def at_all_query(
text: str = 'hello', text: str = "hello",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
**overrides, **overrides,
@@ -462,7 +456,7 @@ def at_all_query(
"""Create a group query with @All mention.""" """Create a group query with @All mention."""
components = [ components = [
platform_message.AtAll(), platform_message.AtAll(),
platform_message.Plain(text=f' {text}'), platform_message.Plain(text=f" {text}"),
] ]
chain = platform_message.MessageChain(components) chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id) event = group_message_event(chain, sender_id, group_id=group_id)
@@ -475,4 +469,4 @@ def at_all_query(
sender_id=sender_id, sender_id=sender_id,
adapter=adapter, adapter=adapter,
**overrides, **overrides,
) )

View File

@@ -33,7 +33,7 @@ class FakePlatform:
def __init__( def __init__(
self, self,
*, *,
bot_account_id: str = 'test-bot', bot_account_id: str = "test-bot",
stream_output_supported: bool = False, stream_output_supported: bool = False,
raise_error: Exception = None, raise_error: Exception = None,
): ):
@@ -48,16 +48,16 @@ class FakePlatform:
# Registered listeners # Registered listeners
self._listeners: dict = {} self._listeners: dict = {}
def raises(self, error: Exception) -> 'FakePlatform': def raises(self, error: Exception) -> "FakePlatform":
"""Configure platform to raise an error on send.""" """Configure platform to raise an error on send."""
self._raise_error = error self._raise_error = error
return self return self
def send_failure(self) -> 'FakePlatform': def send_failure(self) -> "FakePlatform":
"""Configure platform to simulate send failure.""" """Configure platform to simulate send failure."""
return self.raises(Exception('Platform send failure')) return self.raises(Exception("Platform send failure"))
def supports_streaming(self, supported: bool = True) -> 'FakePlatform': def supports_streaming(self, supported: bool = True) -> "FakePlatform":
"""Configure whether streaming output is supported.""" """Configure whether streaming output is supported."""
self._stream_output_supported = supported self._stream_output_supported = supported
return self return self
@@ -89,7 +89,7 @@ class FakePlatform:
self, self,
text: str, text: str,
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
nickname: str = 'TestUser', nickname: str = "TestUser",
) -> platform_events.FriendMessage: ) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event.""" """Create an inbound friend (private) message event."""
sender = platform_entities.Friend( sender = platform_entities.Friend(
@@ -97,13 +97,11 @@ class FakePlatform:
nickname=nickname, nickname=nickname,
remark=None, remark=None,
) )
chain = platform_message.MessageChain( chain = platform_message.MessageChain([
[ platform_message.Plain(text=text),
platform_message.Plain(text=text), ])
]
)
return platform_events.FriendMessage( return platform_events.FriendMessage(
type='FriendMessage', type="FriendMessage",
sender=sender, sender=sender,
message_chain=chain, message_chain=chain,
time=1609459200, time=1609459200,
@@ -113,9 +111,9 @@ class FakePlatform:
self, self,
text: str, text: str,
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
sender_name: str = 'TestUser', sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
group_name: str = 'TestGroup', group_name: str = "TestGroup",
mention_bot: bool = False, mention_bot: bool = False,
) -> platform_events.GroupMessage: ) -> platform_events.GroupMessage:
"""Create an inbound group message event. """Create an inbound group message event.
@@ -144,12 +142,12 @@ class FakePlatform:
components = [] components = []
if mention_bot: if mention_bot:
components.append(platform_message.At(target=self.bot_account_id)) components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=' ')) components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=text)) components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components) chain = platform_message.MessageChain(components)
return platform_events.GroupMessage( return platform_events.GroupMessage(
type='GroupMessage', type="GroupMessage",
sender=sender, sender=sender,
message_chain=chain, message_chain=chain,
time=1609459200, time=1609459200,
@@ -157,8 +155,8 @@ class FakePlatform:
def create_image_message( def create_image_message(
self, self,
url: str = 'https://example.com/image.png', url: str = "https://example.com/image.png",
text: str = '', text: str = "",
sender_id: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345,
is_group: bool = False, is_group: bool = False,
group_id: typing.Union[int, str] = 99999, group_id: typing.Union[int, str] = 99999,
@@ -171,12 +169,12 @@ class FakePlatform:
chain = platform_message.MessageChain(components) chain = platform_message.MessageChain(components)
if is_group: if is_group:
return self.create_group_message('', sender_id, group_id=group_id) return self.create_group_message("", sender_id, group_id=group_id)
# Replace chain # Replace chain
else: else:
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None) sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
return platform_events.FriendMessage( return platform_events.FriendMessage(
type='FriendMessage', type="FriendMessage",
sender=sender, sender=sender,
message_chain=chain, message_chain=chain,
time=1609459200, time=1609459200,
@@ -194,14 +192,12 @@ class FakePlatform:
if self._raise_error: if self._raise_error:
raise self._raise_error raise self._raise_error
self._outbound_messages.append( self._outbound_messages.append({
{ "type": "send",
'type': 'send', "target_type": target_type,
'target_type': target_type, "target_id": target_id,
'target_id': target_id, "message": message,
'message': message, })
}
)
async def reply_message( async def reply_message(
self, self,
@@ -213,15 +209,13 @@ class FakePlatform:
if self._raise_error: if self._raise_error:
raise self._raise_error raise self._raise_error
self._outbound_messages.append( self._outbound_messages.append({
{ "type": "reply",
'type': 'reply', "source_type": message_source.type,
'source_type': message_source.type, "source": message_source,
'source': message_source, "message": message,
'message': message, "quote_origin": quote_origin,
'quote_origin': quote_origin, })
}
)
async def reply_message_chunk( async def reply_message_chunk(
self, self,
@@ -235,17 +229,15 @@ class FakePlatform:
if self._raise_error: if self._raise_error:
raise self._raise_error raise self._raise_error
self._outbound_chunks.append( self._outbound_chunks.append({
{ "type": "reply_chunk",
'type': 'reply_chunk', "source_type": message_source.type,
'source_type': message_source.type, "source": message_source,
'source': message_source, "bot_message": bot_message,
'bot_message': bot_message, "message": message,
'message': message, "quote_origin": quote_origin,
'quote_origin': quote_origin, "is_final": is_final,
'is_final': is_final, })
}
)
async def is_stream_output_supported(self) -> bool: async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported.""" """Return whether streaming output is supported."""
@@ -303,7 +295,7 @@ class FakePlatform:
def fake_platform( def fake_platform(
bot_account_id: str = 'test-bot', bot_account_id: str = "test-bot",
stream_output_supported: bool = False, stream_output_supported: bool = False,
) -> FakePlatform: ) -> FakePlatform:
"""Create a FakePlatform instance.""" """Create a FakePlatform instance."""
@@ -336,7 +328,9 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
adapter.reply_message = AsyncMock(side_effect=platform.reply_message) adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk) adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message) adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported) adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter._fake_platform = platform # Store for assertions adapter._fake_platform = platform # Store for assertions
return adapter return adapter

View File

@@ -27,51 +27,51 @@ class FakeProvider:
Does not require API keys. Does not require API keys.
""" """
PONG_RESPONSE = 'LANGBOT_FAKE_PONG' PONG_RESPONSE = "LANGBOT_FAKE_PONG"
def __init__( def __init__(
self, self,
*, *,
default_response: str = 'fake response', default_response: str = "fake response",
streaming_chunks: list[str] = None, streaming_chunks: list[str] = None,
raise_error: Exception = None, raise_error: Exception = None,
captured_requests: list = None, captured_requests: list = None,
): ):
self._default_response = default_response self._default_response = default_response
self._streaming_chunks = streaming_chunks or ['fake ', 'response'] self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._raise_error = raise_error self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else [] self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> 'FakeProvider': def returns(self, text: str) -> "FakeProvider":
"""Configure provider to return a specific text response.""" """Configure provider to return a specific text response."""
self._default_response = text self._default_response = text
self._streaming_chunks = [text] self._streaming_chunks = [text]
return self return self
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider': def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
"""Configure provider to return streaming chunks.""" """Configure provider to return streaming chunks."""
self._streaming_chunks = chunks self._streaming_chunks = chunks
self._default_response = ''.join(chunks) self._default_response = "".join(chunks)
return self return self
def raises(self, error: Exception) -> 'FakeProvider': def raises(self, error: Exception) -> "FakeProvider":
"""Configure provider to raise an error.""" """Configure provider to raise an error."""
self._raise_error = error self._raise_error = error
return self return self
def timeout(self) -> 'FakeProvider': def timeout(self) -> "FakeProvider":
"""Configure provider to simulate timeout.""" """Configure provider to simulate timeout."""
return self.raises(TimeoutError('Provider timeout')) return self.raises(TimeoutError("Provider timeout"))
def auth_error(self) -> 'FakeProvider': def auth_error(self) -> "FakeProvider":
"""Configure provider to simulate auth error.""" """Configure provider to simulate auth error."""
return self.raises(Exception('Invalid API key')) return self.raises(Exception("Invalid API key"))
def rate_limit(self) -> 'FakeProvider': def rate_limit(self) -> "FakeProvider":
"""Configure provider to simulate rate limit.""" """Configure provider to simulate rate limit."""
return self.raises(Exception('Rate limit exceeded')) return self.raises(Exception("Rate limit exceeded"))
def malformed(self) -> 'FakeProvider': def malformed(self) -> "FakeProvider":
"""Configure provider to simulate malformed response.""" """Configure provider to simulate malformed response."""
self._default_response = None self._default_response = None
return self return self
@@ -87,7 +87,7 @@ class FakeProvider:
def _create_message(self, content: str) -> provider_message.Message: def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content.""" """Create a provider message from text content."""
return provider_message.Message( return provider_message.Message(
role='assistant', role="assistant",
content=content, content=content,
) )
@@ -99,7 +99,7 @@ class FakeProvider:
) -> provider_message.MessageChunk: ) -> provider_message.MessageChunk:
"""Create a provider message chunk.""" """Create a provider message chunk."""
return provider_message.MessageChunk( return provider_message.MessageChunk(
role='assistant', role="assistant",
content=content, content=content,
is_final=is_final, is_final=is_final,
msg_sequence=msg_sequence, msg_sequence=msg_sequence,
@@ -116,15 +116,13 @@ class FakeProvider:
) -> provider_message.Message: ) -> provider_message.Message:
"""Simulate non-streaming LLM invocation.""" """Simulate non-streaming LLM invocation."""
# Capture request for assertions # Capture request for assertions
self._captured_requests.append( self._captured_requests.append({
{ "query_id": query.query_id if query else None,
'query_id': query.query_id if query else None, "model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, "messages": messages,
'messages': messages, "funcs": funcs,
'funcs': funcs, "extra_args": extra_args,
'extra_args': extra_args, })
}
)
# Simulate error if configured # Simulate error if configured
if self._raise_error: if self._raise_error:
@@ -133,7 +131,7 @@ class FakeProvider:
# Return response # Return response
if self._default_response is None: if self._default_response is None:
# Malformed response # Malformed response
return provider_message.Message(role='assistant', content=None) return provider_message.Message(role="assistant", content=None)
return self._create_message(self._default_response) return self._create_message(self._default_response)
@@ -148,16 +146,14 @@ class FakeProvider:
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation.""" """Simulate streaming LLM invocation."""
# Capture request for assertions # Capture request for assertions
self._captured_requests.append( self._captured_requests.append({
{ "query_id": query.query_id if query else None,
'query_id': query.query_id if query else None, "model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, "messages": messages,
'messages': messages, "funcs": funcs,
'funcs': funcs, "extra_args": extra_args,
'extra_args': extra_args, "streaming": True,
'streaming': True, })
}
)
# Simulate error if configured # Simulate error if configured
if self._raise_error: if self._raise_error:
@@ -165,12 +161,12 @@ class FakeProvider:
# Yield chunks # Yield chunks
for i, chunk in enumerate(self._streaming_chunks): for i, chunk in enumerate(self._streaming_chunks):
is_final = i == len(self._streaming_chunks) - 1 is_final = (i == len(self._streaming_chunks) - 1)
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i) yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider( def fake_provider(
default_response: str = 'fake response', default_response: str = "fake response",
) -> FakeProvider: ) -> FakeProvider:
"""Create a FakeProvider with optional default response.""" """Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response) return FakeProvider(default_response=default_response)
@@ -206,8 +202,8 @@ def fake_provider_malformed() -> FakeProvider:
def fake_model( def fake_model(
*, *,
uuid: str = 'test-model-uuid', uuid: str = "test-model-uuid",
name: str = 'test-model', name: str = "test-model",
abilities: list[str] = None, abilities: list[str] = None,
provider: FakeProvider = None, provider: FakeProvider = None,
) -> Mock: ) -> Mock:
@@ -216,7 +212,7 @@ def fake_model(
model.model_entity = Mock() model.model_entity = Mock()
model.model_entity.uuid = uuid model.model_entity.uuid = uuid
model.model_entity.name = name model.model_entity.name = name
model.model_entity.abilities = abilities or ['func_call', 'vision'] model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.extra_args = {} model.model_entity.extra_args = {}
# Attach fake provider # Attach fake provider
@@ -225,4 +221,4 @@ def fake_model(
model.provider = provider model.provider = provider
return model return model

View File

@@ -3,4 +3,4 @@ Integration tests package.
These tests validate real system behavior with actual database/network resources. These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q Run with: uv run pytest tests/integration/ -m "not slow" -q
""" """

View File

@@ -2,4 +2,4 @@
API integration tests package. API integration tests package.
Tests for HTTP API endpoints using Quart test client. Tests for HTTP API endpoints using Quart test client.
""" """

View File

@@ -48,7 +48,6 @@ def mock_circular_import_chain():
clear=clear, clear=clear,
): ):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield yield
@@ -57,12 +56,10 @@ def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse).""" """Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp() app = FakeApp()
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# Auth services # Auth services
app.user_service = Mock() app.user_service = Mock()
@@ -74,29 +71,28 @@ def fake_bot_app():
# Bot service # Bot service
app.bot_service = Mock() app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock( app.bot_service.get_bots = AsyncMock(return_value=[
return_value=[ {
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
]
)
app.bot_service.get_runtime_bot_info = AsyncMock(
return_value={
'uuid': 'test-bot-uuid', 'uuid': 'test-bot-uuid',
'name': 'Test Bot', 'name': 'Test Bot',
'platform': 'telegram', 'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid', 'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
} }
) ])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'}) app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={}) app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock() app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1)) app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.send_message = AsyncMock() app.bot_service.send_message = AsyncMock()
# Platform manager # Platform manager
@@ -122,7 +118,10 @@ class TestBotEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client): async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list.""" """GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'}) response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200 assert response.status_code == 200
data = await response.get_json() data = await response.get_json()
@@ -136,7 +135,7 @@ class TestBotEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/platform/bots', '/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}, json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -148,7 +147,8 @@ class TestBotEndpoints:
async def test_get_single_bot_success(self, quart_test_client): async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info.""" """GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -162,7 +162,7 @@ class TestBotEndpoints:
response = await quart_test_client.put( response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid', '/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}, json={'name': 'Updated Bot'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -173,7 +173,8 @@ class TestBotEndpoints:
async def test_delete_bot_success(self, quart_test_client): async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot.""" """DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete( response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -189,7 +190,7 @@ class TestBotLogsEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs', '/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}, json={'from_index': -1, 'max_count': 10}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -212,8 +213,8 @@ class TestBotSendMessageEndpoint:
json={ json={
'target_type': 'person', 'target_type': 'person',
'target_id': 'user123', 'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}], 'message_chain': [{'type': 'text', 'text': 'Hello'}]
}, }
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -227,7 +228,7 @@ class TestBotSendMessageEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message', '/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'}, headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}, json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -243,8 +244,8 @@ class TestBotSendMessageEndpoint:
json={ json={
'target_type': 'invalid', 'target_type': 'invalid',
'target_id': 'user123', 'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}], 'message_chain': [{'type': 'text', 'text': 'Hello'}]
}, }
) )
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -47,7 +47,6 @@ def mock_circular_import_chain():
clear=clear, clear=clear,
): ):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield yield
@@ -56,12 +55,10 @@ def fake_embed_app():
"""Create FakeApp with embed widget services (module scope).""" """Create FakeApp with embed widget services (module scope)."""
app = FakeApp() app = FakeApp()
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# Create mock web_page_bot with valid UUID format # Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock() mock_bot_entity = Mock()
@@ -86,7 +83,9 @@ def fake_embed_app():
# WebSocket proxy bot with adapter # WebSocket proxy bot with adapter
mock_websocket_adapter = Mock() mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}]) mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.reset_session = Mock() mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock() mock_websocket_adapter.handle_websocket_message = AsyncMock()
@@ -118,7 +117,9 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client): async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS.""" """GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js') response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
assert response.status_code == 200 assert response.status_code == 200
assert 'javascript' in response.content_type assert 'javascript' in response.content_type
@@ -126,14 +127,18 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client): async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400.""" """GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js') response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
assert response.status_code == 400 assert response.status_code == 400
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client): async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404.""" """GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js') response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
assert response.status_code == 404 assert response.status_code == 404
@@ -159,7 +164,8 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_no_secret(self, quart_test_client): async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token.""" """POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'} '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -171,7 +177,8 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_invalid_uuid(self, quart_test_client): async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400.""" """POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'} '/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -180,7 +187,8 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_missing_token(self, quart_test_client): async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400.""" """POST turnstile verify without token returns 400."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={} '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -195,7 +203,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/person returns messages.""" """GET messages/person returns messages."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -208,7 +216,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/group returns messages.""" """GET messages/group returns messages."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -218,7 +226,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages with invalid session_type returns 400.""" """GET messages with invalid session_type returns 400."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'}
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -233,7 +241,7 @@ class TestEmbedResetEndpoint:
"""POST reset/person resets session.""" """POST reset/person resets session."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -244,7 +252,8 @@ class TestEmbedResetEndpoint:
async def test_reset_session_invalid_uuid(self, quart_test_client): async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400.""" """POST reset with invalid UUID returns 400."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'} '/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -260,7 +269,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}, json={'message_id': 'msg-123', 'feedback_type': 1}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -274,7 +283,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}, json={'message_id': 'msg-123', 'feedback_type': 2}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -285,7 +294,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'}, headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}, json={'message_id': 'msg-123', 'feedback_type': 99}
) )
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -49,7 +49,6 @@ def mock_circular_import_chain():
clear=clear, clear=clear,
): ):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield yield
@@ -58,12 +57,10 @@ def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse).""" """Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp() app = FakeApp()
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# Auth services # Auth services
app.user_service = Mock() app.user_service = Mock()
@@ -75,35 +72,33 @@ def fake_knowledge_app():
# Knowledge service # Knowledge service
app.knowledge_service = Mock() app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock( app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
return_value=[ {
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
]
)
app.knowledge_service.get_knowledge_base = AsyncMock(
return_value={
'uuid': 'test-kb-uuid', 'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base', 'name': 'Test Knowledge Base',
'description': 'Test KB description', 'description': 'Test KB description',
'engine_plugin_id': 'test/engine', 'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
} }
) ])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'}) app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={}) app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock() app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock( app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}] {'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
) ])
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'}) app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock() app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}]) app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
# RAG manager # RAG manager
app.rag_mgr = Mock() app.rag_mgr = Mock()
@@ -129,7 +124,8 @@ class TestKnowledgeBaseEndpoints:
async def test_get_knowledge_bases_success(self, quart_test_client): async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list.""" """GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'} '/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -144,7 +140,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/knowledge/bases', '/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}, json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -156,7 +152,8 @@ class TestKnowledgeBaseEndpoints:
async def test_get_single_knowledge_base_success(self, quart_test_client): async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base.""" """GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -170,7 +167,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.put( response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid', '/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}, json={'name': 'Updated KB'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -181,7 +178,8 @@ class TestKnowledgeBaseEndpoints:
async def test_delete_knowledge_base_success(self, quart_test_client): async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base.""" """DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete( response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -195,7 +193,8 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_get_files_success(self, quart_test_client): async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files.""" """GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'} '/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -209,7 +208,7 @@ class TestKnowledgeBaseFilesEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files', '/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}, json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -221,7 +220,8 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_delete_file_from_knowledge_base(self, quart_test_client): async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}.""" """DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete( response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', '/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}, json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -249,7 +249,9 @@ class TestKnowledgeBaseRetrieveEndpoint:
async def test_retrieve_without_query_returns_error(self, quart_test_client): async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400.""" """POST retrieve without query returns 400."""
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={} '/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
) )
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -46,7 +46,6 @@ def mock_circular_import_chain():
clear=clear, clear=clear,
): ):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield yield
@@ -55,12 +54,10 @@ def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope).""" """Create FakeApp with monitoring services (module scope)."""
app = FakeApp() app = FakeApp()
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email # Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock() app.user_service = Mock()
@@ -70,34 +67,40 @@ def fake_monitoring_app():
# Monitoring service # Monitoring service
app.monitoring_service = Mock() app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock( app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
return_value={ 'total_messages': 100,
'total_messages': 100, 'total_llm_calls': 50,
'total_llm_calls': 50, 'total_sessions': 20,
'total_sessions': 20, 'active_sessions': 5,
'active_sessions': 5, 'total_errors': 2,
'total_errors': 2, })
} app.monitoring_service.get_messages = AsyncMock(return_value=(
) [{'id': 'msg-1', 'content': 'test'}], 100
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100)) ))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50)) app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10)) [{'id': 'llm-1'}], 50
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20)) ))
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2)) app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
app.monitoring_service.get_session_analysis = AsyncMock( [{'id': 'emb-1'}], 10
return_value={ ))
'found': True, app.monitoring_service.get_sessions = AsyncMock(return_value=(
'session_id': 'sess-1', [{'session_id': 'sess-1'}], 20
} ))
) app.monitoring_service.get_errors = AsyncMock(return_value=(
app.monitoring_service.get_message_details = AsyncMock( [{'id': 'err-1'}], 2
return_value={ ))
'found': True, app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'message_id': 'msg-1', 'found': True,
} 'session_id': 'sess-1',
) })
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10}) app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12)) app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}]) app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}]) app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}]) app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
@@ -127,7 +130,8 @@ class TestMonitoringOverviewEndpoint:
async def test_get_overview_success(self, quart_test_client): async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics.""" """GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -143,7 +147,8 @@ class TestMonitoringMessagesEndpoint:
async def test_get_messages_success(self, quart_test_client): async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list.""" """GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -160,7 +165,8 @@ class TestMonitoringLLMCallsEndpoint:
async def test_get_llm_calls_success(self, quart_test_client): async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls.""" """GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -174,7 +180,8 @@ class TestMonitoringEmbeddingCallsEndpoint:
async def test_get_embedding_calls_success(self, quart_test_client): async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls.""" """GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -188,7 +195,8 @@ class TestMonitoringSessionsEndpoint:
async def test_get_sessions_success(self, quart_test_client): async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions.""" """GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -202,7 +210,8 @@ class TestMonitoringErrorsEndpoint:
async def test_get_errors_success(self, quart_test_client): async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors.""" """GET /api/v1/monitoring/errors."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -216,7 +225,8 @@ class TestMonitoringAllDataEndpoint:
async def test_get_all_data_success(self, quart_test_client): async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data.""" """GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -232,7 +242,8 @@ class TestMonitoringDetailsEndpoints:
async def test_get_session_analysis(self, quart_test_client): async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis.""" """GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -241,7 +252,8 @@ class TestMonitoringDetailsEndpoints:
async def test_get_message_details(self, quart_test_client): async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details.""" """GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -255,7 +267,8 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_stats(self, quart_test_client): async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats.""" """GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -264,7 +277,8 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_list(self, quart_test_client): async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback.""" """GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -278,7 +292,8 @@ class TestMonitoringExportEndpoint:
async def test_export_messages(self, quart_test_client): async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV.""" """GET export?type=messages returns CSV."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -288,7 +303,8 @@ class TestMonitoringExportEndpoint:
async def test_export_llm_calls(self, quart_test_client): async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV.""" """GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -297,7 +313,8 @@ class TestMonitoringExportEndpoint:
async def test_export_sessions(self, quart_test_client): async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV.""" """GET export?type=sessions returns CSV."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -306,7 +323,8 @@ class TestMonitoringExportEndpoint:
async def test_export_feedback(self, quart_test_client): async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV.""" """GET export?type=feedback returns CSV."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'} '/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -49,7 +49,6 @@ def mock_circular_import_chain():
): ):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield yield
@@ -58,12 +57,10 @@ def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse).""" """Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp() app = FakeApp()
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# Auth services # Auth services
app.user_service = Mock() app.user_service = Mock()
@@ -75,23 +72,27 @@ def fake_provider_app():
# Provider service # Provider service
app.provider_service = Mock() app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock( app.provider_service.get_providers = AsyncMock(return_value=[
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}] {'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
) ])
app.provider_service.get_provider = AsyncMock( app.provider_service.get_provider = AsyncMock(return_value={
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'} 'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
) })
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid') app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={}) app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock() app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock( app.provider_service.get_provider_model_counts = AsyncMock(return_value={
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0} 'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
) })
# LLM model service # LLM model service
app.llm_model_service = Mock() app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}]) app.llm_model_service.get_llm_models = AsyncMock(return_value=[
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'}) {'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'}) app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={}) app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock() app.llm_model_service.delete_llm_model = AsyncMock()
@@ -132,7 +133,8 @@ class TestProviderEndpoints:
async def test_get_providers_success(self, quart_test_client): async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure.""" """GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -155,7 +157,8 @@ class TestProviderEndpoints:
async def test_get_single_provider_success(self, quart_test_client): async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure.""" """GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -174,7 +177,7 @@ class TestProviderEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/provider/providers', '/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}, json={'name': 'New Provider', 'requester': 'chatcmpl'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -191,7 +194,7 @@ class TestProviderEndpoints:
response = await quart_test_client.put( response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid', '/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}, json={'name': 'Updated Provider'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -202,7 +205,8 @@ class TestProviderEndpoints:
async def test_delete_provider_success(self, quart_test_client): async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider.""" """DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete( response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -211,7 +215,8 @@ class TestProviderEndpoints:
async def test_get_provider_includes_model_counts(self, quart_test_client): async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts.""" """GET provider response includes model counts."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -232,7 +237,8 @@ class TestModelEndpoints:
async def test_get_llm_models_success(self, quart_test_client): async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list.""" """GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -244,7 +250,8 @@ class TestModelEndpoints:
async def test_get_single_llm_model_success(self, quart_test_client): async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model.""" """GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -257,7 +264,7 @@ class TestModelEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/provider/models/llm', '/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}, json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -269,7 +276,8 @@ class TestModelEndpoints:
async def test_delete_llm_model_success(self, quart_test_client): async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model.""" """DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete( response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -283,7 +291,8 @@ class TestEmbeddingModelEndpoints:
async def test_get_embedding_models_success(self, quart_test_client): async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list.""" """GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -297,7 +306,7 @@ class TestEmbeddingModelEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/provider/models/embedding', '/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}, json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -314,7 +323,8 @@ class TestRerankModelEndpoints:
async def test_get_rerank_models_success(self, quart_test_client): async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list.""" """GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'} '/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -328,7 +338,7 @@ class TestRerankModelEndpoints:
response = await quart_test_client.post( response = await quart_test_client.post(
'/api/v1/provider/models/rerank', '/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}, headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}, json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -20,7 +20,6 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ============== # ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def mock_circular_import_chain(): def mock_circular_import_chain():
""" """
@@ -70,7 +69,6 @@ def mock_circular_import_chain():
# ============== FAKE APPLICATION FOR API TESTS ============== # ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture @pytest.fixture
def fake_api_app(): def fake_api_app():
""" """
@@ -81,14 +79,12 @@ def fake_api_app():
app = FakeApp() app = FakeApp()
# API-specific config # API-specific config
app.instance_config.data.update( app.instance_config.data.update({
{ 'api': {'port': 5300},
'api': {'port': 5300}, 'plugin': {'enable_marketplace': True},
'plugin': {'enable_marketplace': True}, 'space': {'url': 'https://space.langbot.app'},
'space': {'url': 'https://space.langbot.app'}, 'system': {'allow_modify_login_info': True, 'limitation': {}},
'system': {'allow_modify_login_info': True, 'limitation': {}}, })
}
)
# API-specific services # API-specific services
app.user_service = Mock() app.user_service = Mock()
@@ -122,7 +118,6 @@ def fake_api_app():
# ============== QUART TEST CLIENT FIXTURE ============== # ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture @pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls): async def quart_test_client(fake_api_app, http_controller_cls):
""" """
@@ -140,7 +135,6 @@ async def quart_test_client(fake_api_app, http_controller_cls):
# ============== API SMOKE TESTS ============== # ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain') @pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint: class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test.""" """Tests for /healthz endpoint - simplest smoke test."""
@@ -228,7 +222,8 @@ class TestProtectedEndpoints:
Protected endpoint returns 401 with invalid token. Protected endpoint returns 401 with invalid token.
""" """
response = await quart_test_client.get( response = await quart_test_client.get(
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'} '/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -259,7 +254,10 @@ class TestInvalidPayload:
""" """
POST with wrong JSON structure returns stable error. POST with wrong JSON structure returns stable error.
""" """
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'}) response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
# Should return error with stable JSON structure # Should return error with stable JSON structure
assert response.status_code in (400, 500, 401) assert response.status_code in (400, 500, 401)

View File

@@ -2,4 +2,4 @@
Persistence integration tests package. Persistence integration tests package.
Tests for database migrations and storage behavior. Tests for database migrations and storage behavior.
""" """

View File

@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
@pytest.fixture @pytest.fixture
def sqlite_db_url(tmp_path): def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file.""" """Create SQLite URL with temporary database file."""
db_file = tmp_path / 'test_migrations.db' db_file = tmp_path / "test_migrations.db"
return f'sqlite+aiosqlite:///{db_file}' return f"sqlite+aiosqlite:///{db_file}"
@pytest.fixture @pytest.fixture
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
# Verify revision # Verify revision
rev = await get_alembic_current(sqlite_engine) rev = await get_alembic_current(sqlite_engine)
assert rev is not None, 'Expected a revision after upgrade' assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration # Head should be the latest migration
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}' assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine): async def test_upgrade_idempotent(self, sqlite_engine):
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
await run_alembic_upgrade(sqlite_engine, 'head') await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine) rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f'Expected {rev1}, got {rev2}' assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestSQLiteMigrationFreshDatabase: class TestSQLiteMigrationFreshDatabase:
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
4. Verify revision 4. Verify revision
""" """
# Use different DB file for fresh test # Use different DB file for fresh test
fresh_db_file = tmp_path / 'test_migrations_fresh.db' fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}' fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url) fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB # Create tables on fresh DB
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
# Verify revision # Verify revision
rev = await get_alembic_current(fresh_engine) rev = await get_alembic_current(fresh_engine)
assert rev is not None, 'Expected a revision on fresh DB' assert rev is not None, "Expected a revision on fresh DB"
await fresh_engine.dispose() await fresh_engine.dispose()
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
IMPORTANT: This test verifies the ACTUAL behavior, not accepting IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass. any arbitrary failure with try-except pass.
""" """
fresh_db_file = tmp_path / 'test_empty_migrations.db' fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}' fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url) fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior # Capture the actual behavior
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
# Verify specific behavior - one of two outcomes is expected # Verify specific behavior - one of two outcomes is expected
if actual_result is not None: if actual_result is not None:
# Migration succeeded - verify revision exists # Migration succeeded - verify revision exists
assert actual_result is not None, 'Revision should exist after successful migration' assert actual_result is not None, "Revision should exist after successful migration"
else: else:
# Migration failed - verify the error type is known # Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables # Alembic typically raises specific errors for missing tables
assert actual_error is not None, 'Error should be captured if migration failed' assert actual_error is not None, "Error should be captured if migration failed"
# Log the error type for documentation (don't silently pass) # Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__ error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios # Acceptable error types for empty DB scenarios
acceptable_errors = [ acceptable_errors = [
'OperationalError', # SQLite table not found 'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors 'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors 'CommandError', # Alembic command errors
] ]
assert error_type in acceptable_errors, ( assert error_type in acceptable_errors, (
f'Unexpected error type: {error_type}. ' f"Unexpected error type: {error_type}. "
f'This may indicate a regression in migration behavior. ' f"This may indicate a regression in migration behavior. "
f'Error: {actual_error}' f"Error: {actual_error}"
) )
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
# No stamp - should return None # No stamp - should return None
rev = await get_alembic_current(sqlite_engine) rev = await get_alembic_current(sqlite_engine)
assert rev is None, f'Expected None for unstamped DB, got {rev}' assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine): async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
await run_alembic_stamp(sqlite_engine, '0001_baseline') await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine) rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline' assert rev == '0001_baseline'

View File

@@ -34,14 +34,14 @@ def postgres_url():
"""Get PostgreSQL URL from environment.""" """Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL') url = os.environ.get('TEST_POSTGRES_URL')
if not url: if not url:
pytest.skip('TEST_POSTGRES_URL not set') pytest.skip("TEST_POSTGRES_URL not set")
return url return url
@pytest.fixture @pytest.fixture
async def postgres_engine(postgres_url): async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine.""" """Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT') engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
yield engine yield engine
await engine.dispose() await engine.dispose()
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn: async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists # Drop alembic_version table if exists
try: try:
await conn.execute(text('DROP TABLE IF EXISTS alembic_version')) await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception: except Exception:
pass pass
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn: async with postgres_engine.begin() as conn:
try: try:
await conn.execute(text('DROP TABLE IF EXISTS alembic_version')) await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception: except Exception:
pass pass
@@ -83,7 +83,9 @@ class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL.""" """Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version): async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
""" """
Stamp baseline on existing tables sets correct revision. Stamp baseline on existing tables sets correct revision.
@@ -104,7 +106,9 @@ class TestPostgreSQLMigrationBaseline:
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}" assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version): async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
""" """
Stamp on empty database (no tables) still sets revision. Stamp on empty database (no tables) still sets revision.
@@ -121,7 +125,9 @@ class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL.""" """Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version): async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
""" """
Upgrade from baseline to head applies all migrations. Upgrade from baseline to head applies all migrations.
@@ -143,12 +149,14 @@ class TestPostgreSQLMigrationUpgrade:
# Verify revision # Verify revision
rev = await get_alembic_current(postgres_engine) rev = await get_alembic_current(postgres_engine)
assert rev is not None, 'Expected a revision after upgrade' assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration (0005 for current state) # Head should be the latest migration (0005 for current state)
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}' assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version): async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
""" """
Running upgrade to head multiple times is idempotent. Running upgrade to head multiple times is idempotent.
@@ -172,7 +180,7 @@ class TestPostgreSQLMigrationUpgrade:
await run_alembic_upgrade(postgres_engine, 'head') await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine) rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f'Expected {rev1}, got {rev2}' assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestPostgreSQLMigrationGetCurrent: class TestPostgreSQLMigrationGetCurrent:
@@ -191,7 +199,7 @@ class TestPostgreSQLMigrationGetCurrent:
# No stamp - should return None # No stamp - should return None
rev = await get_alembic_current(postgres_engine) rev = await get_alembic_current(postgres_engine)
assert rev is None, f'Expected None for unstamped DB, got {rev}' assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision( async def test_postgres_get_current_after_stamp_returns_revision(
@@ -206,4 +214,4 @@ class TestPostgreSQLMigrationGetCurrent:
await run_alembic_stamp(postgres_engine, '0001_baseline') await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine) rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline' assert rev == '0001_baseline'

View File

@@ -2,4 +2,4 @@
Pipeline integration tests package. Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner. Tests for full pipeline flow using fake provider/runner.
""" """

View File

@@ -26,7 +26,6 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ============== # ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def mock_circular_import_chain(): def mock_circular_import_chain():
""" """
@@ -104,7 +103,6 @@ def mock_circular_import_chain():
# ============== FAKE RUNNER ============== # ============== FAKE RUNNER ==============
class FakeRunner: class FakeRunner:
"""Minimal fake runner class for pipeline integration tests. """Minimal fake runner class for pipeline integration tests.
@@ -119,13 +117,12 @@ class FakeRunner:
self.config = config or {} self.config = config or {}
self._provider = FakeProvider() self._provider = FakeProvider()
# Instance-level configuration set via class attribute # Instance-level configuration set via class attribute
self._response_text = 'fake response' self._response_text = "fake response"
self._raise_error = None self._raise_error = None
@classmethod @classmethod
def returns(cls, text: str): def returns(cls, text: str):
"""Create a runner class configured to return specific text.""" """Create a runner class configured to return specific text."""
# We create a subclass with configured response # We create a subclass with configured response
class ConfiguredRunner(cls): class ConfiguredRunner(cls):
name = cls.name name = cls.name
@@ -135,13 +132,11 @@ class FakeRunner:
def __init__(self, app=None, config=None): def __init__(self, app=None, config=None):
super().__init__(app, config) super().__init__(app, config)
self._response_text = text self._response_text = text
return ConfiguredRunner return ConfiguredRunner
@classmethod @classmethod
def raises(cls, error: Exception): def raises(cls, error: Exception):
"""Create a runner class configured to raise an error.""" """Create a runner class configured to raise an error."""
class ConfiguredRunner(cls): class ConfiguredRunner(cls):
name = cls.name name = cls.name
_response_text = None _response_text = None
@@ -150,7 +145,6 @@ class FakeRunner:
def __init__(self, app=None, config=None): def __init__(self, app=None, config=None):
super().__init__(app, config) super().__init__(app, config)
self._raise_error = error self._raise_error = error
return ConfiguredRunner return ConfiguredRunner
async def run(self, query): async def run(self, query):
@@ -167,7 +161,6 @@ class FakeRunner:
# ============== PIPELINE APP FIXTURE ============== # ============== PIPELINE APP FIXTURE ==============
@pytest.fixture @pytest.fixture
def pipeline_app(): def pipeline_app():
""" """
@@ -194,7 +187,6 @@ def pipeline_app():
def __init__(self, name, messages): def __init__(self, name, messages):
self.name = name self.name = name
self.messages = messages self.messages = messages
def copy(self): def copy(self):
return MockPrompt(self.name, list(self.messages)) return MockPrompt(self.name, list(self.messages))
@@ -245,17 +237,14 @@ def fake_platform_adapter():
@pytest.fixture @pytest.fixture
def set_fake_runner(): def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners.""" """Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls): def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes # preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls] sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner return _set_runner
# ============== PIPELINE CONFIGURATION ============== # ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config(): def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests.""" """Create minimal pipeline configuration for tests."""
return { return {
@@ -284,7 +273,6 @@ def create_minimal_pipeline_config():
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ============== # ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name): async def collect_processor_results(processor, query, stage_name):
""" """
Helper to handle the coroutine -> async_generator pattern. Helper to handle the coroutine -> async_generator pattern.
@@ -308,7 +296,6 @@ async def collect_processor_results(processor, query, stage_name):
# ============== TESTS ============== # ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain') @pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal: class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain.""" """Tests for real pipeline stage chain."""
@@ -350,7 +337,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Create query with adapter # Create query with adapter
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
@@ -378,7 +365,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
query = text_query('test message content') query = text_query("test message content")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
@@ -409,11 +396,11 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Set fake runner that returns pong # Set fake runner that returns pong
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG') fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner) set_fake_runner(fake_runner)
# Create query # Create query
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = [] query.resp_messages = []
@@ -427,7 +414,6 @@ class TestProcessorStage:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -446,7 +432,7 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Create query # Create query
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
@@ -459,7 +445,6 @@ class TestProcessorStage:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -477,13 +462,13 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Create query # Create query
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = [] query.resp_messages = []
# Create reply chain # Create reply chain
reply_chain = text_chain('plugin response') reply_chain = text_chain("plugin response")
# Mock plugin_connector to prevent default with reply # Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock() mock_event_ctx = Mock()
@@ -494,7 +479,6 @@ class TestProcessorStage:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -518,7 +502,7 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Set fake runner that raises exception # Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded')) fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
set_fake_runner(fake_runner) set_fake_runner(fake_runner)
# Create query with exception handling config # Create query with exception handling config
@@ -526,7 +510,7 @@ class TestRunnerExceptionFlow:
config['output']['misc']['exception-handling'] = 'show-hint' config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.' config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = config query.pipeline_config = config
@@ -539,7 +523,6 @@ class TestRunnerExceptionFlow:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -558,14 +541,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception # Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error')) fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
set_fake_runner(fake_runner) set_fake_runner(fake_runner)
# Create query with show-error mode # Create query with show-error mode
config = create_minimal_pipeline_config() config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error' config['output']['misc']['exception-handling'] = 'show-error'
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = config query.pipeline_config = config
@@ -578,7 +561,6 @@ class TestRunnerExceptionFlow:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -596,14 +578,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Set fake runner that raises exception # Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception('Hidden error')) fake_runner = FakeRunner().raises(Exception("Hidden error"))
set_fake_runner(fake_runner) set_fake_runner(fake_runner)
# Create query with hide mode # Create query with hide mode
config = create_minimal_pipeline_config() config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide' config['output']['misc']['exception-handling'] = 'hide'
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = config query.pipeline_config = config
@@ -616,7 +598,6 @@ class TestRunnerExceptionFlow:
# Create Processor stage # Create Processor stage
from langbot.pkg.pipeline.process import process from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app) processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config) await processor_stage.initialize(query.pipeline_config)
@@ -642,7 +623,7 @@ class TestSendResponseBackStage:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Create query with response message # Create query with response message
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
@@ -685,12 +666,12 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Set fake runner # Set fake runner
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG') fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner) set_fake_runner(fake_runner)
# Create query # Create query
config = create_minimal_pipeline_config() config = create_minimal_pipeline_config()
query = text_query('ping') query = text_query("ping")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = config query.pipeline_config = config
query.resp_messages = [] query.resp_messages = []
@@ -709,7 +690,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock() pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [ pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived mock_event_ctx_processor, # Processor NormalMessageReceived
] ]
@@ -730,7 +711,6 @@ class TestStageChainIntegration:
# Build resp_message_chain from resp_messages # Build resp_message_chain from resp_messages
from tests.factories.message import text_chain from tests.factories.message import text_chain
for resp_msg in query.resp_messages: for resp_msg in query.resp_messages:
if resp_msg.content: if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content)) query.resp_message_chain.append(text_chain(resp_msg.content))
@@ -757,7 +737,7 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter adapter, platform = fake_platform_adapter
# Create query # Create query
query = text_query('hello') query = text_query("hello")
query.adapter = adapter query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config() query.pipeline_config = create_minimal_pipeline_config()
@@ -774,7 +754,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock() pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [ pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived mock_event_ctx_processor, # Processor NormalMessageReceived
] ]
@@ -795,4 +775,4 @@ class TestStageChainIntegration:
assert results[0].result_type == entities.ResultType.INTERRUPT assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages # Chain stops here - no resp_messages
assert len(query.resp_messages) == 0 assert len(query.resp_messages) == 0

View File

@@ -3,4 +3,4 @@ Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases. Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q Run with: uv run pytest tests/smoke/ -q
""" """

View File

@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
assert app.instance_config is not None assert app.instance_config is not None
# Verify default config # Verify default config
assert app.instance_config.data['command']['prefix'] == ['/', '!'] assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data['command']['enable'] is True assert app.instance_config.data["command"]["enable"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_provider_returns_text(self): async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response.""" """Test FakeProvider returns configured response."""
provider = FakeProvider(default_response='test response') provider = FakeProvider(default_response="test response")
# Create mock model with provider # Create mock model with provider
model = fake_model(provider=provider) model = fake_model(provider=provider)
# Create a simple query # Create a simple query
query = text_query('hello') query = text_query("hello")
# Simulate invoke # Simulate invoke
result = await provider.invoke_llm( result = await provider.invoke_llm(
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
) )
assert result is not None assert result is not None
assert result.role == 'assistant' assert result.role == "assistant"
assert result.content == 'test response' assert result.content == "test response"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_provider_pong(self): async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker.""" """Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong() provider = fake_provider_pong()
model = fake_model(provider=provider) model = fake_model(provider=provider)
query = text_query('ping') query = text_query("ping")
result = await provider.invoke_llm( result = await provider.invoke_llm(
query=query, query=query,
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_provider_streaming(self): async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response.""" """Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(['Hello', ' World']) provider = FakeProvider().returns_streaming(["Hello", " World"])
model = fake_model(provider=provider) model = fake_model(provider=provider)
query = text_query('hello') query = text_query("hello")
chunks = [] chunks = []
# invoke_llm_stream returns an async generator, don't await it # invoke_llm_stream returns an async generator, don't await it
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 2 assert len(chunks) == 2
assert chunks[0].content == 'Hello' assert chunks[0].content == "Hello"
assert chunks[1].content == ' World' assert chunks[1].content == " World"
assert chunks[1].is_final is True assert chunks[1].is_final is True
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates timeout error.""" """Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout() provider = FakeProvider().timeout()
model = fake_model(provider=provider) model = fake_model(provider=provider)
query = text_query('hello') query = text_query("hello")
with pytest.raises(TimeoutError, match='Provider timeout'): with pytest.raises(TimeoutError, match="Provider timeout"):
await provider.invoke_llm( await provider.invoke_llm(
query=query, query=query,
model=model, model=model,
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates rate limit error.""" """Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit() provider = FakeProvider().rate_limit()
model = fake_model(provider=provider) model = fake_model(provider=provider)
query = text_query('hello') query = text_query("hello")
with pytest.raises(Exception, match='Rate limit exceeded'): with pytest.raises(Exception, match="Rate limit exceeded"):
await provider.invoke_llm( await provider.invoke_llm(
query=query, query=query,
model=model, model=model,
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
async def test_fake_provider_captures_requests(self): async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments.""" """Test FakeProvider captures request arguments."""
provider = FakeProvider() provider = FakeProvider()
model = fake_model(name='gpt-4', provider=provider) model = fake_model(name="gpt-4", provider=provider)
query = text_query('hello') query = text_query("hello")
await provider.invoke_llm( await provider.invoke_llm(
query=query, query=query,
model=model, model=model,
messages=[{'role': 'user', 'content': 'hello'}], messages=[{"role": "user", "content": "hello"}],
funcs=[{'name': 'test_func'}], funcs=[{"name": "test_func"}],
extra_args={'temperature': 0.7}, extra_args={"temperature": 0.7},
) )
captured = provider.get_captured_requests() captured = provider.get_captured_requests()
assert len(captured) == 1 assert len(captured) == 1
assert captured[0]['model'] == 'gpt-4' assert captured[0]["model"] == "gpt-4"
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}] assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]['funcs'] == [{'name': 'test_func'}] assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]['extra_args'] == {'temperature': 0.7} assert captured[0]["extra_args"] == {"temperature": 0.7}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self): async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages.""" """Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id='test-bot') platform = FakePlatform(bot_account_id="test-bot")
query = text_query('hello') query = text_query("hello")
# Simulate sending reply # Simulate sending reply
from tests.factories.message import text_chain from tests.factories.message import text_chain
reply_chain = text_chain('response text') reply_chain = text_chain("response text")
event = query.message_event event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False) await platform.reply_message(event, reply_chain, quote_origin=False)
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
# Verify captured # Verify captured
outbound = platform.get_outbound_messages() outbound = platform.get_outbound_messages()
assert len(outbound) == 1 assert len(outbound) == 1
assert outbound[0]['type'] == 'reply' assert outbound[0]["type"] == "reply"
assert outbound[0]['message'] == reply_chain assert outbound[0]["message"] == reply_chain
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_platform_friend_message(self): async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events.""" """Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id='test-bot') platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_friend_message( event = platform.create_friend_message(
text='hello bot', text="hello bot",
sender_id=12345, sender_id=12345,
nickname='TestUser', nickname="TestUser",
) )
assert event.type == 'FriendMessage' assert event.type == "FriendMessage"
assert event.sender.id == 12345 assert event.sender.id == 12345
assert event.sender.nickname == 'TestUser' assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == 'hello bot' assert str(event.message_chain) == "hello bot"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self): async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention.""" """Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id='test-bot') platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_group_message( event = platform.create_group_message(
text='hello everyone', text="hello everyone",
sender_id=12345, sender_id=12345,
group_id=99999, group_id=99999,
mention_bot=True, mention_bot=True,
) )
assert event.type == 'GroupMessage' assert event.type == "GroupMessage"
assert event.sender.id == 12345 assert event.sender.id == 12345
assert event.group.id == 99999 assert event.group.id == 99999
@@ -220,57 +220,54 @@ class TestFakeMessageFlow:
async def test_query_factories_basic(self): async def test_query_factories_basic(self):
"""Test basic query factory functions.""" """Test basic query factory functions."""
# Text query # Text query
q1 = text_query('hello world') q1 = text_query("hello world")
assert q1.launcher_type.value == 'person' assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == 'hello world' assert str(q1.message_chain) == "hello world"
# Group query # Group query
from tests.factories import group_text_query from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
q2 = group_text_query('hello group', group_id=88888) assert q2.launcher_type.value == "group"
assert q2.launcher_type.value == 'group'
assert q2.launcher_id == 88888 assert q2.launcher_id == 88888
# Command query # Command query
from tests.factories import command_query from tests.factories import command_query
q3 = command_query("help", prefix="/")
q3 = command_query('help', prefix='/') assert str(q3.message_chain) == "/help"
assert str(q3.message_chain) == '/help'
# Mention query # Mention query
from tests.factories import mention_query from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
q4 = mention_query('hi', target='test-bot', group_id=77777) assert q4.launcher_type.value == "group"
assert q4.launcher_type.value == 'group'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_platform_send_failure(self): async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure.""" """Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure() platform = FakePlatform().send_failure()
query = text_query('hello') query = text_query("hello")
from tests.factories.message import text_chain from tests.factories.message import text_chain
with pytest.raises(Exception, match='Platform send failure'): with pytest.raises(Exception, match="Platform send failure"):
await platform.reply_message( await platform.reply_message(
query.message_event, query.message_event,
text_chain('response'), text_chain("response"),
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mock_platform_adapter(self): async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper.""" """Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id='bot-123') platform = FakePlatform(bot_account_id="bot-123")
adapter = mock_platform_adapter(platform) adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == 'bot-123' assert adapter.bot_account_id == "bot-123"
assert adapter._fake_platform is platform assert adapter._fake_platform is platform
# Test reply_message is wired # Test reply_message is wired
from tests.factories.message import text_chain from tests.factories.message import text_chain
query = text_query('test') query = text_query("test")
await adapter.reply_message(query.message_event, text_chain('response')) await adapter.reply_message(query.message_event, text_chain("response"))
# Verify platform captured it # Verify platform captured it
assert len(platform.get_outbound_messages()) == 1 assert len(platform.get_outbound_messages()) == 1
@@ -296,18 +293,18 @@ class TestMessageFlowIntegration:
Note: This does NOT run actual LangBot pipeline stages. Note: This does NOT run actual LangBot pipeline stages.
""" """
# Setup # Setup
platform = FakePlatform(bot_account_id='test-bot') platform = FakePlatform(bot_account_id="test-bot")
provider = fake_provider_pong() provider = fake_provider_pong()
model = fake_model(provider=provider) model = fake_model(provider=provider)
# Create inbound message # Create inbound message
query = text_query('ping') query = text_query("ping")
# Simulate provider processing # Simulate provider processing
response = await provider.invoke_llm( response = await provider.invoke_llm(
query=query, query=query,
model=model, model=model,
messages=[{'role': 'user', 'content': 'ping'}], messages=[{"role": "user", "content": "ping"}],
funcs=[], funcs=[],
extra_args={}, extra_args={},
) )
@@ -324,16 +321,16 @@ class TestMessageFlowIntegration:
# Verify platform captured outbound # Verify platform captured outbound
outbound = platform.get_outbound_messages() outbound = platform.get_outbound_messages()
assert len(outbound) == 1 assert len(outbound) == 1
assert outbound[0]['type'] == 'reply' assert outbound[0]["type"] == "reply"
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_streaming_message_flow(self): async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow.""" """Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming() platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(['Hello', ' there']) provider = FakeProvider().returns_streaming(["Hello", " there"])
model = fake_model(provider=provider) model = fake_model(provider=provider)
query = text_query('hi') query = text_query("hi")
chunks = [] chunks = []
async for chunk in provider.invoke_llm_stream( async for chunk in provider.invoke_llm_stream(
@@ -347,8 +344,8 @@ class TestMessageFlowIntegration:
# Verify streaming worked # Verify streaming worked
assert len(chunks) == 2 assert len(chunks) == 2
full_content = ''.join(c.content for c in chunks) full_content = "".join(c.content for c in chunks)
assert full_content == 'Hello there' assert full_content == "Hello there"
# Verify platform supports streaming # Verify platform supports streaming
assert await platform.is_stream_output_supported() is True assert await platform.is_stream_output_supported() is True

View File

@@ -15,12 +15,22 @@ import pathlib
# Resolve project root (one level up from tests/) # Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent _PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py' VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
def test_no_exec_call_in_system_controller(): def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input.""" """Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, 'r') as f: with open(VULN_FILE, "r") as f:
source = f.read() source = f.read()
tree = ast.parse(source) tree = ast.parse(source)
@@ -30,26 +40,27 @@ def test_no_exec_call_in_system_controller():
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
func = node.func func = node.func
# Match bare exec() call # Match bare exec() call
if isinstance(func, ast.Name) and func.id == 'exec': if isinstance(func, ast.Name) and func.id == "exec":
exec_calls.append(node.lineno) exec_calls.append(node.lineno)
assert len(exec_calls) == 0, ( assert len(exec_calls) == 0, (
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().' f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
) )
def test_no_debug_exec_route(): def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered.""" """Verify the /debug/exec route is not registered."""
with open(VULN_FILE, 'r') as f: with open(VULN_FILE, "r") as f:
source = f.read() source = f.read()
assert 'debug/exec' not in source, ( assert "debug/exec" not in source, (
'The /debug/exec route still exists in system.py. ' "The /debug/exec route still exists in system.py. "
'This endpoint allows arbitrary code execution and must be removed.' "This endpoint allows arbitrary code execution and must be removed."
) )
if __name__ == '__main__': if __name__ == "__main__":
test_no_exec_call_in_system_controller() test_no_exec_call_in_system_controller()
test_no_debug_exec_route() test_no_debug_exec_route()
print('All tests passed!') print("All tests passed!")

View File

@@ -1 +1 @@
"""Unit tests for LangBot API HTTP service layer.""" """Unit tests for LangBot API HTTP service layer."""

View File

@@ -13,4 +13,4 @@ Does NOT:
- Call real provider/platform/network - Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application. Uses tests.factories.FakeApp as base mock application.
""" """

View File

@@ -132,7 +132,9 @@ class TestApiKeyServiceCreateApiKey:
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'): with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description') result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}] assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert result['key'].startswith('lbk_') assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token' assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key' assert result['name'] == 'New Key'

View File

@@ -303,7 +303,13 @@ class TestBotServiceCreateBot:
ap = SimpleNamespace() ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace() ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace() ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}} ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.platform_mgr = SimpleNamespace() ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock() ap.platform_mgr.load_bot = AsyncMock()
@@ -312,7 +318,9 @@ class TestBotServiceCreateBot:
bot2 = _create_mock_bot(bot_uuid='uuid-2') bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2]) mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
service = BotService(ap) service = BotService(ap)
@@ -344,7 +352,6 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot()) bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -355,7 +362,9 @@ class TestBotServiceCreateBot:
return bot_result # Get bot return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
service = BotService(ap) service = BotService(ap)
@@ -388,7 +397,6 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot()) bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -484,7 +492,6 @@ class TestBotServiceUpdateBot:
pipeline_result.first = Mock(return_value=mock_pipeline) pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -575,9 +582,10 @@ class TestBotServiceListEventLogs:
# Mock runtime bot with logger # Mock runtime bot with logger
runtime_bot = SimpleNamespace() runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace() runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock( runtime_bot.logger.get_logs = AsyncMock(return_value=(
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5) [SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
) 5
))
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot) ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap) service = BotService(ap)
@@ -638,7 +646,11 @@ class TestBotServiceSendMessage:
service = BotService(ap) service = BotService(ap)
# Execute with valid message chain format # Execute with valid message chain format
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]} message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
# Patch the import location - the module imports inside the function # Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain: with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:

View File

@@ -6,7 +6,6 @@ Tests cover:
- Knowledge engine discovery - Knowledge engine discovery
- File operations - File operations
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -53,7 +52,9 @@ class TestGetKnowledgeBases:
"""Test that it returns all knowledge base details.""" """Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module() knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}]) mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases() result = await service.get_knowledge_bases()
@@ -82,7 +83,9 @@ class TestGetKnowledgeBase:
"""Test that it returns specific KB details.""" """Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module() knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'}) mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1') result = await service.get_knowledge_base('kb1')
@@ -150,7 +153,9 @@ class TestCreateKnowledgeBase:
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'}) await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
# Check that default name 'Untitled' was used # Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args call_args = mock_app.rag_mgr.create_knowledge_base.call_args
@@ -165,21 +170,20 @@ class TestUpdateKnowledgeBase:
"""Test that only mutable fields are updated.""" """Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module() knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'}) mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock() mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock() mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields # Pass both mutable and immutable fields
await service.update_knowledge_base( await service.update_knowledge_base('kb1', {
'kb1', 'name': 'New Name',
{ 'description': 'New desc',
'name': 'New Name', 'uuid': 'should_be_filtered', # immutable
'description': 'New desc', })
'uuid': 'should_be_filtered', # immutable
},
)
# Check that only mutable fields were passed to update # Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args call_args = mock_app.persistence_mgr.execute_async.call_args
@@ -284,7 +288,9 @@ class TestListKnowledgeEngines:
"""Test that it returns empty list and logs warning on exception.""" """Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module() knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error')) mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines() result = await service.list_knowledge_engines()
@@ -380,10 +386,12 @@ class TestGetEngineSchemas:
"""Test that it returns empty dict and logs warning on exception.""" """Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module() knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error')) mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
service = knowledge_module.KnowledgeService(mock_app) service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine') result = await service.get_engine_creation_schema('author/engine')
assert result == {} assert result == {}
mock_app.logger.warning.assert_called_once() mock_app.logger.warning.assert_called_once()

View File

@@ -174,7 +174,9 @@ class TestMaintenanceServiceGetStorageAnalysis:
# Setup # Setup
ap = SimpleNamespace() ap = SimpleNamespace()
ap.instance_config = SimpleNamespace() ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}} ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.persistence_mgr = SimpleNamespace() ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace() ap.logger = SimpleNamespace()
ap.logger.warning = Mock() ap.logger.warning = Mock()
@@ -290,8 +292,12 @@ class TestMaintenanceServiceGetStorageAnalysis:
service._file_count = Mock(return_value=0) service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={}) service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0}) service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}]) service._expired_uploaded_candidates = AsyncMock(return_value=[
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}]) {'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
# Execute # Execute
result = await service.get_storage_analysis() result = await service.get_storage_analysis()
@@ -361,7 +367,6 @@ class TestMaintenanceServiceBinaryStorageStats:
size_result = _create_mock_result(scalar_value=5000) size_result = _create_mock_result(scalar_value=5000)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -391,7 +396,6 @@ class TestMaintenanceServiceBinaryStorageStats:
count_result = _create_mock_result(scalar_value=5) count_result = _create_mock_result(scalar_value=5)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -817,4 +821,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
result = service._expired_local_upload_candidates(7, include_paths=True) result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included # Verify - path included
assert 'path' in result[0] assert 'path' in result[0]

View File

@@ -186,7 +186,13 @@ class TestMCPServiceCreateMCPServer:
ap = SimpleNamespace() ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace() ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace() ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}} ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': 2
}
}
}
ap.plugin_connector = SimpleNamespace() ap.plugin_connector = SimpleNamespace()
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
@@ -246,7 +252,6 @@ class TestMCPServiceCreateMCPServer:
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True) server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -356,7 +361,6 @@ class TestMCPServiceUpdateMCPServer:
old_server = _create_mock_mcp_server(name='Old Server', enable=True) old_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -390,7 +394,6 @@ class TestMCPServiceUpdateMCPServer:
updated_server = _create_mock_mcp_server(name='Old Server', enable=True) updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -429,7 +432,6 @@ class TestMCPServiceUpdateMCPServer:
# Mock for: first select -> update -> second select (for updated server) # Mock for: first select -> update -> second select (for updated server)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -463,7 +465,6 @@ class TestMCPServiceUpdateMCPServer:
# Mock execute for select and update # Mock execute for select and update
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -498,7 +499,6 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Server to Delete') server = _create_mock_mcp_server(name='Server to Delete')
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -530,7 +530,6 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Not in Sessions') server = _create_mock_mcp_server(name='Not in Sessions')
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -560,7 +559,6 @@ class TestMCPServiceDeleteMCPServer:
# No server found # No server found
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -598,7 +596,9 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session) ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
ap.task_mgr = SimpleNamespace() ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123)) ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=123)
)
service = MCPService(ap) service = MCPService(ap)
@@ -634,7 +634,9 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session) ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace() ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456)) ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=456)
)
service = MCPService(ap) service = MCPService(ap)
@@ -643,4 +645,4 @@ class TestMCPServiceTestMCPServer:
# Verify - load_mcp_server called # Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once() ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456 assert task_id == 456

View File

@@ -167,7 +167,6 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([]) mock_provider_result = _create_mock_result([])
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
return mock_result if call_count == 0 else mock_provider_result return mock_result if call_count == 0 else mock_provider_result
@@ -201,7 +200,6 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider]) mock_provider_result = _create_mock_result([provider])
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -241,7 +239,6 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider]) mock_provider_result = _create_mock_result([provider])
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -282,7 +279,6 @@ class TestLLMModelsServiceGetLLMModel:
mock_provider_result = _create_mock_result([], first_item=provider) mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -341,7 +337,9 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
mock_result = _create_mock_result([model1, model2]) mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'model-1', 'name': 'Model 1'}
)
service = LLMModelsService(ap) service = LLMModelsService(ap)
@@ -373,14 +371,12 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap) service = LLMModelsService(ap)
# Execute # Execute
model_uuid = await service.create_llm_model( model_uuid = await service.create_llm_model({
{ 'name': 'New LLM',
'name': 'New LLM', 'provider_uuid': 'provider-uuid',
'provider_uuid': 'provider-uuid', 'abilities': [],
'abilities': [], 'extra_args': {},
'extra_args': {}, })
}
)
# Verify # Verify
assert model_uuid is not None assert model_uuid is not None
@@ -404,16 +400,13 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap) service = LLMModelsService(ap)
# Execute # Execute
model_uuid = await service.create_llm_model( model_uuid = await service.create_llm_model({
{ 'uuid': 'preserved-uuid',
'uuid': 'preserved-uuid', 'name': 'Preserved UUID Model',
'name': 'Preserved UUID Model', 'provider_uuid': 'provider-uuid',
'provider_uuid': 'provider-uuid', 'abilities': [],
'abilities': [], 'extra_args': {},
'extra_args': {}, }, preserve_uuid=True)
},
preserve_uuid=True,
)
# Verify # Verify
assert model_uuid == 'preserved-uuid' assert model_uuid == 'preserved-uuid'
@@ -466,14 +459,12 @@ class TestLLMModelsServiceCreateLLMModel:
# Execute & Verify # Execute & Verify
with pytest.raises(Exception, match='provider not found'): with pytest.raises(Exception, match='provider not found'):
await service.create_llm_model( await service.create_llm_model({
{ 'name': 'No Provider Model',
'name': 'No Provider Model', 'provider_uuid': 'nonexistent-provider',
'provider_uuid': 'nonexistent-provider', 'abilities': [],
'abilities': [], 'extra_args': {},
'extra_args': {}, })
}
)
async def test_create_llm_model_with_provider_data(self): async def test_create_llm_model_with_provider_data(self):
"""Creates provider when provider data provided.""" """Creates provider when provider data provided."""
@@ -499,18 +490,16 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap) service = LLMModelsService(ap)
# Execute - with provider data (no UUID) # Execute - with provider data (no UUID)
result_uuid = await service.create_llm_model( result_uuid = await service.create_llm_model({
{ 'name': 'Model with New Provider',
'name': 'Model with New Provider', 'provider': {
'provider': { 'requester': 'openai',
'requester': 'openai', 'base_url': 'https://api.openai.com',
'base_url': 'https://api.openai.com', 'api_keys': ['key'],
'api_keys': ['key'], },
}, 'abilities': [],
'abilities': [], 'extra_args': {},
'extra_args': {}, })
}
)
# Verify - provider_service was called and UUID generated # Verify - provider_service was called and UUID generated
ap.provider_service.find_or_create_provider.assert_called_once() ap.provider_service.find_or_create_provider.assert_called_once()
@@ -536,14 +525,11 @@ class TestLLMModelsServiceUpdateLLMModel:
service = LLMModelsService(ap) service = LLMModelsService(ap)
# Execute # Execute
await service.update_llm_model( await service.update_llm_model('existing-uuid', {
'existing-uuid', 'uuid': 'should-be-removed',
{ 'name': 'Updated Name',
'uuid': 'should-be-removed', 'provider_uuid': 'provider-uuid',
'name': 'Updated Name', })
'provider_uuid': 'provider-uuid',
},
)
# Verify - remove and load called # Verify - remove and load called
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid') ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
@@ -563,13 +549,10 @@ class TestLLMModelsServiceUpdateLLMModel:
# Execute & Verify # Execute & Verify
with pytest.raises(Exception, match='provider not found'): with pytest.raises(Exception, match='provider not found'):
await service.update_llm_model( await service.update_llm_model('model-uuid', {
'model-uuid', 'name': 'Update',
{ 'provider_uuid': 'nonexistent-provider',
'name': 'Update', })
'provider_uuid': 'nonexistent-provider',
},
)
async def test_update_llm_model_reloads_context_length_as_column(self): async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args.""" """Updates runtime model with context_length outside extra_args."""
@@ -635,7 +618,9 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_result = _create_mock_result([]) mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
)
service = EmbeddingModelsService(ap) service = EmbeddingModelsService(ap)
@@ -658,7 +643,6 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_provider_result = _create_mock_result([provider]) mock_provider_result = _create_mock_result([provider])
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -699,7 +683,6 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
mock_provider_result = _create_mock_result([], first_item=provider) mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -759,13 +742,11 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
service = EmbeddingModelsService(ap) service = EmbeddingModelsService(ap)
# Execute # Execute
model_uuid = await service.create_embedding_model( model_uuid = await service.create_embedding_model({
{ 'name': 'New Embedding',
'name': 'New Embedding', 'provider_uuid': 'provider-uuid',
'provider_uuid': 'provider-uuid', 'extra_args': {},
'extra_args': {}, })
}
)
# Verify # Verify
assert model_uuid is not None assert model_uuid is not None
@@ -786,13 +767,11 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
# Execute & Verify # Execute & Verify
with pytest.raises(Exception, match='provider not found'): with pytest.raises(Exception, match='provider not found'):
await service.create_embedding_model( await service.create_embedding_model({
{ 'name': 'No Provider Embedding',
'name': 'No Provider Embedding', 'provider_uuid': 'nonexistent',
'provider_uuid': 'nonexistent', 'extra_args': {},
'extra_args': {}, })
}
)
class TestEmbeddingModelsServiceDeleteEmbeddingModel: class TestEmbeddingModelsServiceDeleteEmbeddingModel:
@@ -850,7 +829,6 @@ class TestRerankModelsServiceGetRerankModels:
mock_provider_result = _create_mock_result([provider]) mock_provider_result = _create_mock_result([provider])
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -891,7 +869,6 @@ class TestRerankModelsServiceGetRerankModel:
mock_provider_result = _create_mock_result([], first_item=provider) mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -951,13 +928,11 @@ class TestRerankModelsServiceCreateRerankModel:
service = RerankModelsService(ap) service = RerankModelsService(ap)
# Execute # Execute
model_uuid = await service.create_rerank_model( model_uuid = await service.create_rerank_model({
{ 'name': 'New Rerank',
'name': 'New Rerank', 'provider_uuid': 'provider-uuid',
'provider_uuid': 'provider-uuid', 'extra_args': {},
'extra_args': {}, })
}
)
# Verify # Verify
assert model_uuid is not None assert model_uuid is not None
@@ -977,13 +952,11 @@ class TestRerankModelsServiceCreateRerankModel:
# Execute & Verify # Execute & Verify
with pytest.raises(Exception, match='provider not found'): with pytest.raises(Exception, match='provider not found'):
await service.create_rerank_model( await service.create_rerank_model({
{ 'name': 'No Provider Rerank',
'name': 'No Provider Rerank', 'provider_uuid': 'nonexistent',
'provider_uuid': 'nonexistent', 'extra_args': {},
'extra_args': {}, })
}
)
class TestRerankModelsServiceDeleteRerankModel: class TestRerankModelsServiceDeleteRerankModel:
@@ -1022,7 +995,9 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
mock_result = _create_mock_result([model1, model2]) mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
)
service = EmbeddingModelsService(ap) service = EmbeddingModelsService(ap)
@@ -1047,7 +1022,9 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
mock_result = _create_mock_result([model1, model2]) mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
)
service = RerankModelsService(ap) service = RerankModelsService(ap)
@@ -1065,10 +1042,14 @@ class TestValidateProviderSupports:
def _make_ap(requester_name: str, support_type): def _make_ap(requester_name: str, support_type):
"""Build a fake ap whose model_mgr resolves a manifest with support_type.""" """Build a fake ap whose model_mgr resolves a manifest with support_type."""
manifest = SimpleNamespace(spec={'support_type': support_type}) manifest = SimpleNamespace(spec={'support_type': support_type})
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name)) runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester=requester_name)
)
model_mgr = SimpleNamespace( model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider}, provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None, get_available_requester_manifest_by_name=lambda name: manifest
if name == requester_name
else None,
) )
return SimpleNamespace(model_mgr=model_mgr) return SimpleNamespace(model_mgr=model_mgr)
@@ -1085,7 +1066,9 @@ class TestValidateProviderSupports:
async def test_allows_when_support_type_missing(self): async def test_allows_when_support_type_missing(self):
# Manifest without support_type must not block (backward compatible) # Manifest without support_type must not block (backward compatible)
manifest = SimpleNamespace(spec={}) manifest = SimpleNamespace(spec={})
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy')) runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester='legacy')
)
model_mgr = SimpleNamespace( model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider}, provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest, get_available_requester_manifest_by_name=lambda name: manifest,

View File

@@ -215,7 +215,13 @@ class TestPipelineServiceCreatePipeline:
ap = SimpleNamespace() ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace() ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace() ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}} ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.pipeline_mgr = SimpleNamespace() ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock() ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace() ap.ver_mgr = SimpleNamespace()
@@ -223,7 +229,9 @@ class TestPipelineServiceCreatePipeline:
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()]) mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
)
service = PipelineService(ap) service = PipelineService(ap)
@@ -250,14 +258,14 @@ class TestPipelineServiceCreatePipeline:
# Mock persistence for insert # Mock persistence for insert
ap.persistence_mgr.execute_async = AsyncMock() ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
)
# Mock the file read for default config - patch at the utils module level # Mock the file read for default config - patch at the utils module level
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}} default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch( with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'}) bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
# Verify # Verify
@@ -278,9 +286,7 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap) service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[]) service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock( service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
)
ap.persistence_mgr.execute_async = AsyncMock() ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock( ap.persistence_mgr.serialize_model = Mock(
@@ -290,9 +296,7 @@ class TestPipelineServiceCreatePipeline:
# Mock the file read # Mock the file read
default_config = {} default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch( with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'Default Pipeline'}, default=True) await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
# Verify - execute was called # Verify - execute was called
@@ -312,12 +316,10 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap) service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[]) service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock( service.get_pipeline = AsyncMock(return_value={
return_value={ 'uuid': 'new-uuid',
'uuid': 'new-uuid', 'extensions_preferences': {},
'extensions_preferences': {}, })
}
)
insert_params = [] insert_params = []
@@ -337,9 +339,7 @@ class TestPipelineServiceCreatePipeline:
default_config = {} default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch( with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'New Pipeline'}) await service.create_pipeline({'name': 'New Pipeline'})
assert len(insert_params) == 1 assert len(insert_params) == 1
@@ -353,7 +353,6 @@ class TestPipelineServiceCreatePipeline:
class _MockResultWithBots: class _MockResultWithBots:
"""Helper class to mock SQLAlchemy result with iterable .all() method.""" """Helper class to mock SQLAlchemy result with iterable .all() method."""
def __init__(self, bots_list): def __init__(self, bots_list):
self._bots_list = bots_list self._bots_list = bots_list
@@ -429,7 +428,6 @@ class TestPipelineServiceUpdatePipeline:
# 1. UPDATE (line 125) - returns Mock (no result needed) # 1. UPDATE (line 125) - returns Mock (no result needed)
# 2. SELECT bots (line 136) - returns bot_result with .all() # 2. SELECT bots (line 136) - returns bot_result with .all()
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -530,7 +528,13 @@ class TestPipelineServiceCopyPipeline:
ap = SimpleNamespace() ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace() ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace() ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}} ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.pipeline_mgr = SimpleNamespace() ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock() ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace() ap.ver_mgr = SimpleNamespace()
@@ -538,12 +542,10 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap) service = PipelineService(ap)
# Mock get_pipelines to return 2 pipelines # Mock get_pipelines to return 2 pipelines
service.get_pipelines = AsyncMock( service.get_pipelines = AsyncMock(return_value=[
return_value=[ {'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-1', 'name': 'Pipeline 1'}, {'uuid': 'uuid-2', 'name': 'Pipeline 2'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'}, ])
]
)
# Execute & Verify # Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'): with pytest.raises(ValueError, match='Maximum number of pipelines'):
@@ -640,7 +642,9 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap) service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[]) service.get_pipelines = AsyncMock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original)) ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False}) ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'copy-uuid', 'is_default': False}
)
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False}) service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
@@ -677,10 +681,11 @@ class TestPipelineServiceUpdatePipelineExtensions:
ap.pipeline_mgr.remove_pipeline = AsyncMock() ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock() ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []}) original_pipeline = _create_mock_pipeline(
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -695,7 +700,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': { 'extensions_preferences': {
'enable_all_plugins': False, 'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}], 'plugins': [{'plugin_uuid': 'plugin-1'}],
}, }
} }
) )
@@ -706,7 +711,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': { 'extensions_preferences': {
'enable_all_plugins': False, 'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}], 'plugins': [{'plugin_uuid': 'plugin-1'}],
}, }
} }
) )
@@ -733,7 +738,6 @@ class TestPipelineServiceUpdatePipelineExtensions:
original_pipeline = _create_mock_pipeline() original_pipeline = _create_mock_pipeline()
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -748,7 +752,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': { 'extensions_preferences': {
'enable_all_mcp_servers': False, 'enable_all_mcp_servers': False,
'mcp_servers': ['mcp-server-1'], 'mcp_servers': ['mcp-server-1'],
}, }
} }
) )
@@ -790,7 +794,6 @@ class TestPipelineServiceUpdatePipelineExtensions:
) )
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1

View File

@@ -245,14 +245,12 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap) service = ModelProviderService(ap)
# Execute # Execute
provider_uuid = await service.create_provider( provider_uuid = await service.create_provider({
{ 'name': 'New Provider',
'name': 'New Provider', 'requester': 'openai',
'requester': 'openai', 'base_url': 'https://api.openai.com',
'base_url': 'https://api.openai.com', 'api_keys': ['key'],
'api_keys': ['key'], })
}
)
# Verify - UUID is generated # Verify - UUID is generated
assert provider_uuid is not None assert provider_uuid is not None
@@ -276,14 +274,12 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap) service = ModelProviderService(ap)
# Execute # Execute
result_uuid = await service.create_provider( result_uuid = await service.create_provider({
{ 'name': 'Runtime Provider',
'name': 'Runtime Provider', 'requester': 'openai',
'requester': 'openai', 'base_url': 'https://api.openai.com',
'base_url': 'https://api.openai.com', 'api_keys': ['key'],
'api_keys': ['key'], })
}
)
# Verify - provider added to runtime dict and UUID generated # Verify - provider added to runtime dict and UUID generated
ap.model_mgr.load_provider.assert_called_once() ap.model_mgr.load_provider.assert_called_once()
@@ -306,13 +302,10 @@ class TestModelProviderServiceUpdateProvider:
service = ModelProviderService(ap) service = ModelProviderService(ap)
# Execute # Execute
await service.update_provider( await service.update_provider('existing-uuid', {
'existing-uuid', 'uuid': 'should-be-removed', # Will be removed
{ 'name': 'Updated Name',
'uuid': 'should-be-removed', # Will be removed })
'name': 'Updated Name',
},
)
# Verify - reload called # Verify - reload called
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid') ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
@@ -371,7 +364,6 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=None) rerank_result.first = Mock(return_value=None)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -404,7 +396,6 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -463,7 +454,6 @@ class TestModelProviderServiceGetProviderModelCounts:
rerank_result.scalar = Mock(return_value=1) rerank_result.scalar = Mock(return_value=1)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -647,7 +637,9 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
await service.update_space_model_provider_api_keys('space-api-key') await service.update_space_model_provider_api_keys('space-api-key')
# Verify - update and reload called for Space provider UUID # Verify - update and reload called for Space provider UUID
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000') ap.model_mgr.reload_provider.assert_called_once_with(
'00000000-0000-0000-0000-000000000000'
)
class TestModelProviderServiceScanProviderModels: class TestModelProviderServiceScanProviderModels:
@@ -803,7 +795,9 @@ class TestModelProviderServiceScanProviderModels:
runtime_provider.token_mgr = Mock() runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token') runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token'] runtime_provider.token_mgr.tokens = ['token']
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported')) runtime_provider.requester.scan_models = AsyncMock(
side_effect=NotImplementedError('scan not supported')
)
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider) ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
service = ModelProviderService(ap) service = ModelProviderService(ap)
@@ -854,7 +848,9 @@ class TestModelProviderServiceScanProviderModels:
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider) ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing LLM model # Mock existing LLM model
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}]) ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
return_value=[{'name': 'Existing Model'}]
)
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[]) ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap) service = ModelProviderService(ap)
@@ -867,4 +863,4 @@ class TestModelProviderServiceScanProviderModels:
assert existing_model['already_added'] is True assert existing_model['already_added'] is True
new_model = next(m for m in result['models'] if m['name'] == 'New Model') new_model = next(m for m in result['models'] if m['name'] == 'New Model')
assert new_model['already_added'] is False assert new_model['already_added'] is False

View File

@@ -393,16 +393,14 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response # Mock HTTP response
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status = 200 mock_response.status = 200
mock_response.json = AsyncMock( mock_response.json = AsyncMock(return_value={
return_value={ 'code': 0,
'code': 0, 'data': {
'data': { 'access_token': 'new_access_token',
'access_token': 'new_access_token', 'refresh_token': 'new_refresh_token',
'refresh_token': 'new_refresh_token', 'expires_in': 3600,
'expires_in': 3600,
},
} }
) })
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock() mock_session_obj = MagicMock()
@@ -431,12 +429,10 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response with error # Mock HTTP response with error
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status = 200 mock_response.status = 200
mock_response.json = AsyncMock( mock_response.json = AsyncMock(return_value={
return_value={ 'code': 1,
'code': 1, 'msg': 'Invalid refresh token',
'msg': 'Invalid refresh token', })
}
)
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}') mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
@@ -493,16 +489,14 @@ class TestSpaceServiceExchangeOAuthCode:
# Mock HTTP response # Mock HTTP response
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status = 200 mock_response.status = 200
mock_response.json = AsyncMock( mock_response.json = AsyncMock(return_value={
return_value={ 'code': 0,
'code': 0, 'data': {
'data': { 'access_token': 'new_access_token',
'access_token': 'new_access_token', 'refresh_token': 'new_refresh_token',
'refresh_token': 'new_refresh_token', 'expires_in': 3600,
'expires_in': 3600,
},
} }
) })
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock() mock_session_obj = MagicMock()
@@ -561,15 +555,13 @@ class TestSpaceServiceGetUserInfoRaw:
# Mock HTTP response # Mock HTTP response
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status = 200 mock_response.status = 200
mock_response.json = AsyncMock( mock_response.json = AsyncMock(return_value={
return_value={ 'code': 0,
'code': 0, 'data': {
'data': { 'email': 'test@example.com',
'email': 'test@example.com', 'credits': 100,
'credits': 100,
},
} }
) })
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock() mock_session_obj = MagicMock()
@@ -677,29 +669,27 @@ class TestSpaceServiceGetModels:
# Mock HTTP response with proper model data matching SpaceModel schema # Mock HTTP response with proper model data matching SpaceModel schema
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status = 200 mock_response.status = 200
mock_response.json = AsyncMock( mock_response.json = AsyncMock(return_value={
return_value={ 'code': 0,
'code': 0, 'data': {
'data': { 'models': [
'models': [ {
{ 'uuid': 'uuid-1',
'uuid': 'uuid-1', 'model_id': 'model-1',
'model_id': 'model-1', 'provider': 'provider-1',
'provider': 'provider-1', 'category': 'chat',
'category': 'chat', 'status': 'active',
'status': 'active', },
}, {
{ 'uuid': 'uuid-2',
'uuid': 'uuid-2', 'model_id': 'model-2',
'model_id': 'model-2', 'provider': 'provider-2',
'provider': 'provider-2', 'category': 'chat',
'category': 'chat', 'status': 'active',
'status': 'active', },
}, ]
]
},
} }
) })
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock() mock_session_obj = MagicMock()
@@ -785,4 +775,4 @@ class TestSpaceServiceCreditsCache:
# Verify - cache updated # Verify - cache updated
assert result == 500 assert result == 500
assert 'test@example.com' in service._credits_cache assert 'test@example.com' in service._credits_cache
assert service._credits_cache['test@example.com'][0] == 500 assert service._credits_cache['test@example.com'][0] == 500

View File

@@ -495,7 +495,6 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user # First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0 call_count = 0
async def mock_get_by_space_uuid(uuid): async def mock_get_by_space_uuid(uuid):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -566,7 +565,6 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user # First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0 call_count = 0
async def mock_get_by_space_uuid(uuid): async def mock_get_by_space_uuid(uuid):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -607,4 +605,4 @@ class TestUserServiceCreateUserLock:
# Verify lock exists # Verify lock exists
assert hasattr(service, '_create_user_lock') assert hasattr(service, '_create_user_lock')
assert service._create_user_lock is not None assert service._create_user_lock is not None

View File

@@ -132,7 +132,6 @@ class TestWebhookServiceCreateWebhook:
# execute_async returns different results # execute_async returns different results
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -182,7 +181,6 @@ class TestWebhookServiceCreateWebhook:
) )
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -219,7 +217,6 @@ class TestWebhookServiceCreateWebhook:
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False) created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
call_count = 0 call_count = 0
async def mock_execute(query): async def mock_execute(query):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@@ -228,7 +225,9 @@ class TestWebhookServiceCreateWebhook:
return _create_mock_result(first_item=created_webhook) return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False}) ap.persistence_mgr.serialize_model = Mock(
return_value={'id': 1, 'enabled': False}
)
service = WebhookService(ap) service = WebhookService(ap)
@@ -504,4 +503,4 @@ class TestWebhookServiceGetEnabledWebhooks:
result = await service.get_enabled_webhooks() result = await service.get_enabled_webhooks()
# Verify - should be empty (SQL would filter disabled) # Verify - should be empty (SQL would filter disabled)
assert result == [] assert result == []

View File

@@ -407,9 +407,7 @@ def test_box_service_forced_template_ignores_pipeline_config():
launcher_type='person', launcher_type='person',
launcher_id='test_user', launcher_id='test_user',
sender_id='test_user', sender_id='test_user',
pipeline_config={ pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
},
) )
assert service.resolve_box_session_id(query) == 'global' assert service.resolve_box_session_id(query) == 'global'
@@ -1529,7 +1527,9 @@ class TestBuildSkillExtraMounts:
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'}, {'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
] ]
# No skill is dropped, so no "missing" warning should be logged. # No skill is dropped, so no "missing" warning should be logged.
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list) assert not any(
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
)
def test_skips_skill_with_empty_package_root(self): def test_skips_skill_with_empty_package_root(self):
logger = Mock() logger = Mock()

View File

@@ -1 +1 @@
# Unit tests for command module # Unit tests for command module

View File

@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
# Should yield CommandNotFoundError (no such command registered) # Should yield CommandNotFoundError (no such command registered)
assert len(results) == 1 assert len(results) == 1
assert results[0].error is not None assert results[0].error is not None

View File

@@ -197,7 +197,6 @@ class TestCommandOperatorBase:
op = TestOperator(None) op = TestOperator(None)
# Should not raise # Should not raise
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(op.initialize()) asyncio.get_event_loop().run_until_complete(op.initialize())
def test_execute_is_abstract(self): def test_execute_is_abstract(self):
@@ -300,4 +299,4 @@ class TestMultipleOperators:
yield None yield None
assert AdminOperator.lowest_privilege == 2 assert AdminOperator.lowest_privilege == 2
assert SubOperator.lowest_privilege == 1 assert SubOperator.lowest_privilege == 1

View File

@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_yaml_loads(self, tmp_path): async def test_valid_yaml_loads(self, tmp_path):
"""Valid YAML config should load correctly.""" """Valid YAML config should load correctly."""
config_file = tmp_path / 'test_config.yaml' config_file = tmp_path / "test_config.yaml"
# Write valid YAML # Write valid YAML
config_file.write_text(""" config_file.write_text("""
@@ -51,7 +51,7 @@ settings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_yaml_raises_error(self, tmp_path): async def test_invalid_yaml_raises_error(self, tmp_path):
"""Invalid YAML should raise clear error.""" """Invalid YAML should raise clear error."""
config_file = tmp_path / 'invalid.yaml' config_file = tmp_path / "invalid.yaml"
# Write invalid YAML (unclosed bracket) # Write invalid YAML (unclosed bracket)
config_file.write_text(""" config_file.write_text("""
@@ -67,13 +67,13 @@ settings:
template_data={'name': 'default'}, template_data={'name': 'default'},
) )
with pytest.raises(Exception, match='Syntax error'): with pytest.raises(Exception, match="Syntax error"):
await yaml_file.load(completion=False) await yaml_file.load(completion=False)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_config_creates_from_template(self, tmp_path): async def test_missing_config_creates_from_template(self, tmp_path):
"""Missing config file should be created from template.""" """Missing config file should be created from template."""
config_file = tmp_path / 'new_config.yaml' config_file = tmp_path / "new_config.yaml"
# File doesn't exist yet # File doesn't exist yet
assert not config_file.exists() assert not config_file.exists()
@@ -92,7 +92,7 @@ settings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_template_completion(self, tmp_path): async def test_template_completion(self, tmp_path):
"""Config should be completed with template defaults.""" """Config should be completed with template defaults."""
config_file = tmp_path / 'partial.yaml' config_file = tmp_path / "partial.yaml"
# Write partial config missing some template keys # Write partial config missing some template keys
config_file.write_text(""" config_file.write_text("""
@@ -115,7 +115,7 @@ name: custom_name
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_yaml_save(self, tmp_path): async def test_yaml_save(self, tmp_path):
"""YAML config can be saved.""" """YAML config can be saved."""
config_file = tmp_path / 'save_test.yaml' config_file = tmp_path / "save_test.yaml"
yaml_file = YAMLConfigFile( yaml_file = YAMLConfigFile(
str(config_file), str(config_file),
@@ -131,7 +131,7 @@ name: custom_name
def test_yaml_save_sync(self, tmp_path): def test_yaml_save_sync(self, tmp_path):
"""YAML config can be saved synchronously.""" """YAML config can be saved synchronously."""
config_file = tmp_path / 'sync_save.yaml' config_file = tmp_path / "sync_save.yaml"
yaml_file = YAMLConfigFile( yaml_file = YAMLConfigFile(
str(config_file), str(config_file),
@@ -151,18 +151,14 @@ class TestJSONConfigFile:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_json_loads(self, tmp_path): async def test_valid_json_loads(self, tmp_path):
"""Valid JSON config should load correctly.""" """Valid JSON config should load correctly."""
config_file = tmp_path / 'test_config.json' config_file = tmp_path / "test_config.json"
# Write valid JSON # Write valid JSON
config_file.write_text( config_file.write_text(json.dumps({
json.dumps( 'name': 'json_app',
{ 'version': '1.0',
'name': 'json_app', 'settings': {'debug': True, 'port': 8080},
'version': '1.0', }))
'settings': {'debug': True, 'port': 8080},
}
)
)
json_file = JSONConfigFile( json_file = JSONConfigFile(
str(config_file), str(config_file),
@@ -178,7 +174,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_json_raises_error(self, tmp_path): async def test_invalid_json_raises_error(self, tmp_path):
"""Invalid JSON should raise clear error.""" """Invalid JSON should raise clear error."""
config_file = tmp_path / 'invalid.json' config_file = tmp_path / "invalid.json"
# Write invalid JSON (missing closing brace) # Write invalid JSON (missing closing brace)
config_file.write_text('{"name": "test", "unclosed": ') config_file.write_text('{"name": "test", "unclosed": ')
@@ -188,13 +184,13 @@ class TestJSONConfigFile:
template_data={'name': 'default'}, template_data={'name': 'default'},
) )
with pytest.raises(Exception, match='Syntax error'): with pytest.raises(Exception, match="Syntax error"):
await json_file.load(completion=False) await json_file.load(completion=False)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_missing_json_creates_from_template(self, tmp_path): async def test_missing_json_creates_from_template(self, tmp_path):
"""Missing JSON file should be created from template.""" """Missing JSON file should be created from template."""
config_file = tmp_path / 'new_config.json' config_file = tmp_path / "new_config.json"
json_file = JSONConfigFile( json_file = JSONConfigFile(
str(config_file), str(config_file),
@@ -209,7 +205,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_json_save(self, tmp_path): async def test_json_save(self, tmp_path):
"""JSON config can be saved.""" """JSON config can be saved."""
config_file = tmp_path / 'save_test.json' config_file = tmp_path / "save_test.json"
json_file = JSONConfigFile( json_file = JSONConfigFile(
str(config_file), str(config_file),
@@ -230,7 +226,7 @@ class TestConfigManager:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_config_manager_load(self, tmp_path): async def test_config_manager_load(self, tmp_path):
"""ConfigManager loads config correctly.""" """ConfigManager loads config correctly."""
config_file = tmp_path / 'manager_test.yaml' config_file = tmp_path / "manager_test.yaml"
config_file.write_text('name: managed_app\nversion: "1.0"\n') config_file.write_text('name: managed_app\nversion: "1.0"\n')
yaml_file = YAMLConfigFile( yaml_file = YAMLConfigFile(
@@ -247,7 +243,7 @@ class TestConfigManager:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_config_manager_dump(self, tmp_path): async def test_config_manager_dump(self, tmp_path):
"""ConfigManager can dump config.""" """ConfigManager can dump config."""
config_file = tmp_path / 'dump_test.yaml' config_file = tmp_path / "dump_test.yaml"
yaml_file = YAMLConfigFile( yaml_file = YAMLConfigFile(
str(config_file), str(config_file),
@@ -264,7 +260,7 @@ class TestConfigManager:
def test_config_manager_dump_sync(self, tmp_path): def test_config_manager_dump_sync(self, tmp_path):
"""ConfigManager can dump config synchronously.""" """ConfigManager can dump config synchronously."""
config_file = tmp_path / 'sync_dump.yaml' config_file = tmp_path / "sync_dump.yaml"
yaml_file = YAMLConfigFile( yaml_file = YAMLConfigFile(
str(config_file), str(config_file),
@@ -284,7 +280,7 @@ class TestConfigExists:
def test_yaml_exists_true(self, tmp_path): def test_yaml_exists_true(self, tmp_path):
"""exists() returns True for existing file.""" """exists() returns True for existing file."""
config_file = tmp_path / 'exists.yaml' config_file = tmp_path / "exists.yaml"
config_file.write_text('name: test') config_file.write_text('name: test')
yaml_file = YAMLConfigFile(str(config_file), template_data={}) yaml_file = YAMLConfigFile(str(config_file), template_data={})
@@ -292,14 +288,14 @@ class TestConfigExists:
def test_yaml_exists_false(self, tmp_path): def test_yaml_exists_false(self, tmp_path):
"""exists() returns False for missing file.""" """exists() returns False for missing file."""
config_file = tmp_path / 'missing.yaml' config_file = tmp_path / "missing.yaml"
yaml_file = YAMLConfigFile(str(config_file), template_data={}) yaml_file = YAMLConfigFile(str(config_file), template_data={})
assert yaml_file.exists() is False assert yaml_file.exists() is False
def test_json_exists_true(self, tmp_path): def test_json_exists_true(self, tmp_path):
"""exists() returns True for existing JSON file.""" """exists() returns True for existing JSON file."""
config_file = tmp_path / 'exists.json' config_file = tmp_path / "exists.json"
config_file.write_text('{}') config_file.write_text('{}')
json_file = JSONConfigFile(str(config_file), template_data={}) json_file = JSONConfigFile(str(config_file), template_data={})
@@ -307,7 +303,7 @@ class TestConfigExists:
def test_json_exists_false(self, tmp_path): def test_json_exists_false(self, tmp_path):
"""exists() returns False for missing JSON file.""" """exists() returns False for missing JSON file."""
config_file = tmp_path / 'missing.json' config_file = tmp_path / "missing.json"
json_file = JSONConfigFile(str(config_file), template_data={}) json_file = JSONConfigFile(str(config_file), template_data={})
assert json_file.exists() is False assert json_file.exists() is False

View File

@@ -1 +1 @@
"""Core module unit tests.""" """Core module unit tests."""

View File

@@ -4,7 +4,6 @@ Tests cover:
- _get_positive_int_config() validation - _get_positive_int_config() validation
- _get_positive_float_config() validation - _get_positive_float_config() validation
""" """
from __future__ import annotations from __future__ import annotations
from unittest.mock import Mock from unittest.mock import Mock
@@ -189,4 +188,4 @@ class TestGetPositiveFloatConfig:
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config') result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
assert result == 1.5 assert result == 1.5
mock_logger.warning.assert_called_once() mock_logger.warning.assert_called_once()

View File

@@ -27,7 +27,6 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps from langbot.pkg.core.bootutils.deps import check_deps
import asyncio import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps()) result = asyncio.get_event_loop().run_until_complete(check_deps())
assert result == [] assert result == []
@@ -47,7 +46,6 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps from langbot.pkg.core.bootutils.deps import check_deps
import asyncio import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps()) result = asyncio.get_event_loop().run_until_complete(check_deps())
assert 'requests' in result assert 'requests' in result
@@ -63,7 +61,6 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps, required_deps from langbot.pkg.core.bootutils.deps import check_deps, required_deps
import asyncio import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps()) result = asyncio.get_event_loop().run_until_complete(check_deps())
# Should include all required_deps keys # Should include all required_deps keys
@@ -110,7 +107,6 @@ class TestPrecheckPluginDeps:
with patch('os.path.exists', return_value=False): with patch('os.path.exists', return_value=False):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install: with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps()) asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_not_called() mock_install.assert_not_called()
@@ -133,7 +129,6 @@ class TestPrecheckPluginDeps:
with patch('os.listdir', side_effect=mock_listdir): with patch('os.listdir', side_effect=mock_listdir):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install: with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps()) asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[]) mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])

View File

@@ -7,7 +7,6 @@ Tests cover:
- Dict type skipping - Dict type skipping
- Missing key creation - Missing key creation
""" """
from __future__ import annotations from __future__ import annotations
import os import os
@@ -249,8 +248,15 @@ class TestApplyEnvOverridesToConfig:
"""Test multiple env vars applied in order.""" """Test multiple env vars applied in order."""
load_config = get_load_config_module() load_config = get_load_config_module()
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}} cfg = {
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'} 'system': {'name': 'default', 'enable': True},
'concurrency': {'pipeline': 5}
}
env = {
'SYSTEM__NAME': 'custom',
'SYSTEM__ENABLE': 'false',
'CONCURRENCY__PIPELINE': '10'
}
with patch.dict(os.environ, env, clear=True): with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg) result = load_config._apply_env_overrides_to_config(cfg)
@@ -281,4 +287,4 @@ class TestApplyEnvOverridesToConfig:
with patch.dict(os.environ, env, clear=True): with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg) result = load_config._apply_env_overrides_to_config(cfg)
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com' assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'

View File

@@ -175,4 +175,4 @@ class TestPreregisteredStages:
pass pass
for key in preregistered_stages: for key in preregistered_stages:
assert isinstance(key, str) assert isinstance(key, str)

View File

@@ -7,7 +7,6 @@ Tests cover:
Note: Uses import_isolation to break circular import chains. Note: Uses import_isolation to break circular import chains.
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -20,17 +19,15 @@ from typing import Generator
class MockLifecycleControlScopeEnum: class MockLifecycleControlScopeEnum:
"""Mock enum value for LifecycleControlScope with .value attribute.""" """Mock enum value for LifecycleControlScope with .value attribute."""
def __init__(self, value: str): def __init__(self, value: str):
self.value = value self.value = value
def __repr__(self): def __repr__(self):
return f'LifecycleControlScope.{self.value.upper()}' return f"LifecycleControlScope.{self.value.upper()}"
class MockLifecycleControlScope: class MockLifecycleControlScope:
"""Mock enum for LifecycleControlScope.""" """Mock enum for LifecycleControlScope."""
APPLICATION = MockLifecycleControlScopeEnum('application') APPLICATION = MockLifecycleControlScopeEnum('application')
PLATFORM = MockLifecycleControlScopeEnum('platform') PLATFORM = MockLifecycleControlScopeEnum('platform')
PIPELINE = MockLifecycleControlScopeEnum('pipeline') PIPELINE = MockLifecycleControlScopeEnum('pipeline')
@@ -43,17 +40,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
# Mock modules that cause circular imports # Mock modules that cause circular imports
mock_entities = MagicMock() mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope mock_entities.LifecycleControlScope = MockLifecycleControlScope
mock_app = MagicMock() mock_app = MagicMock()
mock_importutil = MagicMock() mock_importutil = MagicMock()
mock_importutil.import_modules_in_pkg = lambda pkg: None mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None mock_importutil.import_modules_in_pkgs = lambda pkgs: None
mock_http_controller = MagicMock() mock_http_controller = MagicMock()
mock_rag_mgr = MagicMock() mock_rag_mgr = MagicMock()
mocks = { mocks = {
'langbot.pkg.core.entities': mock_entities, 'langbot.pkg.core.entities': mock_entities,
'langbot.pkg.core.app': mock_app, 'langbot.pkg.core.app': mock_app,
@@ -61,26 +58,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr, 'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
'langbot.pkg.utils.importutil': mock_importutil, 'langbot.pkg.utils.importutil': mock_importutil,
} }
# Save original state # Save original state
saved = {} saved = {}
for name in mocks: for name in mocks:
if name in sys.modules: if name in sys.modules:
saved[name] = sys.modules[name] saved[name] = sys.modules[name]
# Clear taskmgr to force re-import # Clear taskmgr to force re-import
taskmgr_name = 'langbot.pkg.core.taskmgr' taskmgr_name = 'langbot.pkg.core.taskmgr'
if taskmgr_name in sys.modules: if taskmgr_name in sys.modules:
saved[taskmgr_name] = sys.modules[taskmgr_name] saved[taskmgr_name] = sys.modules[taskmgr_name]
try: try:
# Apply mocks # Apply mocks
for name, module in mocks.items(): for name, module in mocks.items():
sys.modules[name] = module sys.modules[name] = module
# Clear taskmgr # Clear taskmgr
sys.modules.pop(taskmgr_name, None) sys.modules.pop(taskmgr_name, None)
yield yield
finally: finally:
# Restore # Restore
@@ -89,7 +86,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
sys.modules[name] = saved[name] sys.modules[name] = saved[name]
else: else:
sys.modules.pop(name, None) sys.modules.pop(name, None)
if taskmgr_name in saved: if taskmgr_name in saved:
sys.modules[taskmgr_name] = saved[taskmgr_name] sys.modules[taskmgr_name] = saved[taskmgr_name]
else: else:
@@ -100,7 +97,6 @@ def get_taskmgr_classes():
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes.""" """Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
with isolated_taskmgr_import(): with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
return TaskContext, TaskWrapper, AsyncTaskManager return TaskContext, TaskWrapper, AsyncTaskManager
@@ -198,10 +194,9 @@ class TestTaskContext:
"""Test TaskContext.placeholder() returns singleton.""" """Test TaskContext.placeholder() returns singleton."""
with isolated_taskmgr_import(): with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext from langbot.pkg.core.taskmgr import TaskContext
# Reset global placeholder # Reset global placeholder
import langbot.pkg.core.taskmgr as taskmgr_module import langbot.pkg.core.taskmgr as taskmgr_module
taskmgr_module.placeholder_context = None taskmgr_module.placeholder_context = None
ctx1 = TaskContext.placeholder() ctx1 = TaskContext.placeholder()
@@ -274,8 +269,7 @@ class TestTaskWrapper:
return 'result' return 'result'
wrapper = TaskWrapper( wrapper = TaskWrapper(
mock_app, mock_app, immediate_coro(),
immediate_coro(),
name='test_task', name='test_task',
label='Test Task', label='Test Task',
) )
@@ -420,7 +414,7 @@ class TestAsyncTaskManager:
async def test_cancel_by_scope(self): async def test_cancel_by_scope(self):
"""Test cancel_by_scope cancels matching tasks.""" """Test cancel_by_scope cancels matching tasks."""
_, _, AsyncTaskManager = get_taskmgr_classes() _, _, AsyncTaskManager = get_taskmgr_classes()
mock_app = create_mock_app() mock_app = create_mock_app()
manager = AsyncTaskManager(mock_app) manager = AsyncTaskManager(mock_app)
@@ -428,10 +422,16 @@ class TestAsyncTaskManager:
await asyncio.sleep(10) await asyncio.sleep(10)
# Create task with APPLICATION scope # Create task with APPLICATION scope
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION]) w1 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.APPLICATION]
)
# Create task with different scope # Create task with different scope
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE]) w2 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.PIPELINE]
)
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION) manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)

View File

@@ -15,68 +15,68 @@ class TestI18nString:
def test_create_with_english_only(self): def test_create_with_english_only(self):
"""Create I18nString with only English.""" """Create I18nString with only English."""
i18n = I18nString(en_US='Hello') i18n = I18nString(en_US="Hello")
assert i18n.en_US == 'Hello' assert i18n.en_US == "Hello"
assert i18n.zh_Hans is None assert i18n.zh_Hans is None
def test_create_with_multiple_languages(self): def test_create_with_multiple_languages(self):
"""Create I18nString with multiple languages.""" """Create I18nString with multiple languages."""
i18n = I18nString( i18n = I18nString(
en_US='Hello', en_US="Hello",
zh_Hans='你好', zh_Hans="你好",
zh_Hant='你好', zh_Hant="你好",
ja_JP='こんにちは', ja_JP="こんにちは",
) )
assert i18n.en_US == 'Hello' assert i18n.en_US == "Hello"
assert i18n.zh_Hans == '你好' assert i18n.zh_Hans == "你好"
assert i18n.zh_Hant == '你好' assert i18n.zh_Hant == "你好"
assert i18n.ja_JP == 'こんにちは' assert i18n.ja_JP == "こんにちは"
def test_to_dict_with_english_only(self): def test_to_dict_with_english_only(self):
"""to_dict returns only non-None fields.""" """to_dict returns only non-None fields."""
i18n = I18nString(en_US='Hello') i18n = I18nString(en_US="Hello")
result = i18n.to_dict() result = i18n.to_dict()
assert result == {'en_US': 'Hello'} assert result == {"en_US": "Hello"}
def test_to_dict_with_multiple_languages(self): def test_to_dict_with_multiple_languages(self):
"""to_dict returns all non-None fields.""" """to_dict returns all non-None fields."""
i18n = I18nString( i18n = I18nString(
en_US='Hello', en_US="Hello",
zh_Hans='你好', zh_Hans="你好",
) )
result = i18n.to_dict() result = i18n.to_dict()
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'} assert result == {"en_US": "Hello", "zh_Hans": "你好"}
def test_to_dict_excludes_none(self): def test_to_dict_excludes_none(self):
"""to_dict excludes None values.""" """to_dict excludes None values."""
i18n = I18nString( i18n = I18nString(
en_US='Hello', en_US="Hello",
zh_Hans=None, zh_Hans=None,
ja_JP='こんにちは', ja_JP="こんにちは",
) )
result = i18n.to_dict() result = i18n.to_dict()
assert 'zh_Hans' not in result assert "zh_Hans" not in result
assert 'en_US' in result assert "en_US" in result
assert 'ja_JP' in result assert "ja_JP" in result
def test_to_dict_all_languages(self): def test_to_dict_all_languages(self):
"""to_dict with all supported languages.""" """to_dict with all supported languages."""
i18n = I18nString( i18n = I18nString(
en_US='Hello', en_US="Hello",
zh_Hans='你好', zh_Hans="你好",
zh_Hant='你好', zh_Hant="你好",
ja_JP='こんにちは', ja_JP="こんにちは",
th_TH='สวัสดี', th_TH="สวัสดี",
vi_VN='Xin chào', vi_VN="Xin chào",
es_ES='Hola', es_ES="Hola",
) )
result = i18n.to_dict() result = i18n.to_dict()
@@ -92,30 +92,30 @@ class TestMetadata:
from langbot.pkg.discover.engine import I18nString from langbot.pkg.discover.engine import I18nString
metadata = Metadata( metadata = Metadata(
name='test-component', name="test-component",
label=I18nString(en_US='Test Component'), label=I18nString(en_US="Test Component"),
) )
assert metadata.name == 'test-component' assert metadata.name == "test-component"
assert metadata.label.en_US == 'Test Component' assert metadata.label.en_US == "Test Component"
def test_create_with_all_fields(self): def test_create_with_all_fields(self):
"""Create Metadata with all optional fields.""" """Create Metadata with all optional fields."""
from langbot.pkg.discover.engine import I18nString from langbot.pkg.discover.engine import I18nString
metadata = Metadata( metadata = Metadata(
name='test-component', name="test-component",
label=I18nString(en_US='Test'), label=I18nString(en_US="Test"),
description=I18nString(en_US='A test component'), description=I18nString(en_US="A test component"),
version='1.0.0', version="1.0.0",
icon='test-icon', icon="test-icon",
author='Test Author', author="Test Author",
repository='https://github.com/test/repo', repository="https://github.com/test/repo",
) )
assert metadata.version == '1.0.0' assert metadata.version == "1.0.0"
assert metadata.icon == 'test-icon' assert metadata.icon == "test-icon"
assert metadata.author == 'Test Author' assert metadata.author == "Test Author"
class TestComponentManifest: class TestComponentManifest:

View File

@@ -7,7 +7,6 @@ Tests cover:
Note: Uses import isolation to break circular import chains. Note: Uses import isolation to break circular import chains.
""" """
from __future__ import annotations from __future__ import annotations
import sys import sys
@@ -87,7 +86,6 @@ def get_database_module():
"""Get database module with import isolation.""" """Get database module with import isolation."""
with isolated_database_import(): with isolated_database_import():
from langbot.pkg.persistence import database from langbot.pkg.persistence import database
return database return database
@@ -200,4 +198,4 @@ class TestManagerClassDecorator:
# Create instance to test method (with mock app) # Create instance to test method (with mock app)
mock_app = Mock() mock_app = Mock()
instance = ManagerWithMethods(mock_app) instance = ManagerWithMethods(mock_app)
assert instance.custom_method() == 'test_value' assert instance.custom_method() == 'test_value'

View File

@@ -4,7 +4,6 @@ Tests cover:
- execute_async() with mock database - execute_async() with mock database
- get_db_engine() with mock database manager - get_db_engine() with mock database manager
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -86,7 +85,7 @@ class TestExecuteAsync:
mock_db.get_engine = Mock(return_value=mock_engine) mock_db.get_engine = Mock(return_value=mock_engine)
mgr.db = mock_db mgr.db = mock_db
result = await mgr.execute_async(sqlalchemy.text('SELECT 1')) result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
# Verify result is the same object returned by execute # Verify result is the same object returned by execute
assert result is mock_result assert result is mock_result
@@ -153,4 +152,4 @@ class TestSerializeModelEdgeCases:
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name']) result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
# Result should be empty dict when all columns masked # Result should be empty dict when all columns masked
assert result == {} assert result == {}

View File

@@ -5,7 +5,6 @@ Tests cover:
- datetime conversion to isoformat - datetime conversion to isoformat
- masked_columns exclusion - masked_columns exclusion
""" """
from __future__ import annotations from __future__ import annotations
import datetime import datetime

View File

@@ -49,7 +49,7 @@ class TestPendingMessage:
"""PendingMessage should be created with correct fields.""" """PendingMessage should be created with correct fields."""
aggregator = get_aggregator_module() aggregator = get_aggregator_module()
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -88,7 +88,7 @@ class TestSessionBuffer:
"""SessionBuffer should accept initial messages.""" """SessionBuffer should accept initial messages."""
aggregator = get_aggregator_module() aggregator = get_aggregator_module()
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain1 = text_chain('hello') chain1 = text_chain("hello")
chain2 = text_chain('world') chain2 = text_chain("world")
event = friend_message_event(chain1) event = friend_message_event(chain1)
adapter = mock_adapter() adapter = mock_adapter()
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
# Should contain both messages with separator # Should contain both messages with separator
merged_str = str(merged.message_chain) merged_str = str(merged.message_chain)
assert 'hello' in merged_str assert "hello" in merged_str
assert 'world' in merged_str assert "world" in merged_str
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self): def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed.""" """Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain1 = text_chain('first') chain1 = text_chain("first")
chain2 = text_chain('second') chain2 = text_chain("second")
event = friend_message_event(chain1) event = friend_message_event(chain1)
adapter = mock_adapter() adapter = mock_adapter()
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
app = make_aggregator_app() app = make_aggregator_app()
agg = aggregator.MessageAggregator(app) agg = aggregator.MessageAggregator(app)
chain = text_chain('hello') chain = text_chain("hello")
event = friend_message_event(chain) event = friend_message_event(chain)
adapter = mock_adapter() adapter = mock_adapter()

View File

@@ -15,7 +15,6 @@ from tests.factories import FakeApp
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== # ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def mock_circular_import_chain(): def mock_circular_import_chain():
""" """
@@ -37,11 +36,9 @@ def mock_circular_import_chain():
# Create a default runner that yields a simple response # Create a default runner that yields a simple response
class DefaultRunner: class DefaultRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
yield Message(role='assistant', content='fake response') yield Message(role='assistant', content='fake response')
@@ -73,12 +70,9 @@ def mock_event_ctx():
@pytest.fixture @pytest.fixture
def set_runner(): def set_runner():
"""Factory fixture to set a custom runner for tests.""" """Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class): def _set_runner(runner_class):
import sys import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class] sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner return _set_runner
@@ -93,7 +87,6 @@ def get_chat_handler():
global _chat_handler_module global _chat_handler_module
if _chat_handler_module is None: if _chat_handler_module is None:
from importlib import import_module from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat') _chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module return _chat_handler_module
@@ -103,14 +96,12 @@ def get_entities():
global _entities_module global _entities_module
if _entities_module is None: if _entities_module is None:
from importlib import import_module from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities') _entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module return _entities_module
# ============== REAL ChatMessageHandler Tests ============== # ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain') @pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal: class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class.""" """Tests for real ChatMessageHandler class."""
@@ -197,11 +188,9 @@ class TestChatMessageHandlerReal:
class QuickRunner: class QuickRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
yield Message(role='assistant', content='ok') yield Message(role='assistant', content='ok')
@@ -233,11 +222,9 @@ class TestChatMessageHandlerReal:
class SingleRunner: class SingleRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
yield Message(role='assistant', content='response') yield Message(role='assistant', content='response')
@@ -275,11 +262,9 @@ class TestChatHandlerStreaming:
class StreamRunner: class StreamRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False) yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True) yield MessageChunk(role='assistant', content=' World', is_final=True)
@@ -318,19 +303,14 @@ class TestChatHandlerExceptions:
query.pipeline_config = { query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}}, 'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': { 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
} }
class FailingRunner: class FailingRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
raise ValueError('API error') raise ValueError('API error')
yield yield
@@ -366,19 +346,14 @@ class TestChatHandlerExceptions:
query.pipeline_config = { query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}}, 'output': {'misc': {'exception-handling': 'show-error'}},
'ai': { 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
} }
class ErrorRunner: class ErrorRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
raise ValueError('Custom error') raise ValueError('Custom error')
yield yield
@@ -411,19 +386,14 @@ class TestChatHandlerExceptions:
query.pipeline_config = { query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}}, 'output': {'misc': {'exception-handling': 'hide'}},
'ai': { 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
} }
class HideErrorRunner: class HideErrorRunner:
name = 'local-agent' name = 'local-agent'
def __init__(self, app, config): def __init__(self, app, config):
self.app = app self.app = app
self.config = config self.config = config
async def run(self, query): async def run(self, query):
raise RuntimeError('hidden') raise RuntimeError('hidden')
yield yield
@@ -463,4 +433,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler() chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app) handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line') result = handler.cut_str('first line\nsecond line')
assert '...' in result assert '...' in result

View File

@@ -67,11 +67,7 @@ def make_pipeline_config(**overrides):
for key, value in overrides.items(): for key, value in overrides.items():
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict): if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
for sub_key, sub_value in value.items(): for sub_key, sub_value in value.items():
if ( if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
sub_key in base_config[key]
and isinstance(base_config[key][sub_key], dict)
and isinstance(sub_value, dict)
):
base_config[key][sub_key].update(sub_value) base_config[key][sub_key].update(sub_value)
else: else:
base_config[key][sub_key] = sub_value base_config[key][sub_key] = sub_value
@@ -145,7 +141,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello world') query = text_query("hello world")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -167,7 +163,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
# Empty message chain # Empty message chain
query = text_query('') query = text_query("")
query.message_chain = platform_message.MessageChain([]) query.message_chain = platform_message.MessageChain([])
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
@@ -189,7 +185,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query(' ') # Only whitespace query = text_query(" ") # Only whitespace
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -238,7 +234,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello world') query = text_query("hello world")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -270,7 +266,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('/help me') query = text_query("/help me")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -298,7 +294,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('http://example.com') query = text_query("http://example.com")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -326,7 +322,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('normal message') query = text_query("normal message")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -347,7 +343,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('/help me') query = text_query("/help me")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')
@@ -372,10 +368,12 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
# Add a response message # Add a response message
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')] query.resp_messages = [
provider_message.Message(role='assistant', content='Hello back!')
]
result = await stage.process(query, 'PostContentFilterStage') result = await stage.process(query, 'PostContentFilterStage')
@@ -400,9 +398,11 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_messages = [provider_message.Message(role='assistant', content='Response')] query.resp_messages = [
provider_message.Message(role='assistant', content='Response')
]
result = await stage.process(query, 'PostContentFilterStage') result = await stage.process(query, 'PostContentFilterStage')
@@ -422,7 +422,7 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
# Non-string content - use model_construct to bypass validation # Non-string content - use model_construct to bypass validation
# The actual content type could be a list of ContentElement objects # The actual content type could be a list of ContentElement objects
@@ -450,9 +450,11 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_messages = [provider_message.Message(role='assistant', content='')] query.resp_messages = [
provider_message.Message(role='assistant', content='')
]
result = await stage.process(query, 'PostContentFilterStage') result = await stage.process(query, 'PostContentFilterStage')
@@ -474,7 +476,7 @@ class TestContentFilterStageInvalidName:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
with pytest.raises(ValueError, match='未知的 stage_inst_name'): with pytest.raises(ValueError, match='未知的 stage_inst_name'):
@@ -504,7 +506,7 @@ class TestContentIgnoreFilterDirect:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('normal message without prefix') query = text_query("normal message without prefix")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage') result = await stage.process(query, 'PreContentFilterStage')

View File

@@ -15,7 +15,6 @@ from tests.factories import FakeApp, command_query
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== # ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def mock_circular_import_chain(): def mock_circular_import_chain():
""" """
@@ -57,7 +56,6 @@ def mock_event_ctx():
@pytest.fixture @pytest.fixture
def mock_execute_factory(): def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators.""" """Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute( def _create_execute(
text: str | None = 'ok', text: str | None = 'ok',
error: str | None = None, error: str | None = None,
@@ -73,9 +71,7 @@ def mock_execute_factory():
ret.image_base64 = image_base64 ret.image_base64 = image_base64
ret.file_url = file_url ret.file_url = file_url
yield ret yield ret
return mock_execute return mock_execute
return _create_execute return _create_execute
@@ -90,7 +86,6 @@ def get_command_handler():
global _command_handler_module global _command_handler_module
if _command_handler_module is None: if _command_handler_module is None:
from importlib import import_module from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command') _command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module return _command_handler_module
@@ -100,14 +95,12 @@ def get_entities():
global _entities_module global _entities_module
if _entities_module is None: if _entities_module is None:
from importlib import import_module from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities') _entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module return _entities_module
# ============== REAL CommandHandler Tests ============== # ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain') @pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal: class TestCommandHandlerReal:
"""Tests for real CommandHandler class.""" """Tests for real CommandHandler class."""
@@ -134,7 +127,6 @@ class TestCommandHandlerReal:
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
executed_commands = [] executed_commands = []
async def track_execute(command_text, full_command_text, query, session): async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text) executed_commands.append(command_text)
ret = Mock() ret = Mock()
@@ -342,7 +334,8 @@ class TestCommandHandlerReal:
command = get_command_handler() command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory( fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:', image_url='https://example.com/image.png' text='Here is the image:',
image_url='https://example.com/image.png'
) )
handler = command.CommandHandler(fake_app) handler = command.CommandHandler(fake_app)
@@ -400,4 +393,4 @@ class TestCommandHandlerHelper:
command = get_command_handler() command = get_command_handler()
handler = command.CommandHandler(fake_app) handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line') result = handler.cut_str('first line\nsecond line')
assert '...' in result assert '...' in result

View File

@@ -126,9 +126,11 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])] query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="very long response")])
]
result = await stage.process(query, 'LongTextProcessStage') result = await stage.process(query, 'LongTextProcessStage')
@@ -149,9 +151,11 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])] query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="short response")])
]
result = await stage.process(query, 'LongTextProcessStage') result = await stage.process(query, 'LongTextProcessStage')
@@ -175,13 +179,14 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
# Non-Plain component (Image) # Non-Plain component (Image)
query.resp_message_chain = [ query.resp_message_chain = [
platform_message.MessageChain( platform_message.MessageChain([
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')] platform_message.Plain(text="short"),
) platform_message.Image(url="https://example.com/img.png")
])
] ]
result = await stage.process(query, 'LongTextProcessStage') result = await stage.process(query, 'LongTextProcessStage')
@@ -208,7 +213,7 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
@@ -227,7 +232,7 @@ class TestLongTextProcessStageProcess:
stage = longtext.LongTextProcessStage(app) stage = longtext.LongTextProcessStage(app)
stage.strategy_impl = AsyncMock() stage.strategy_impl = AsyncMock()
query = text_query('hello') query = text_query("hello")
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1) query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
query.resp_message_chain = [] query.resp_message_chain = []
@@ -237,7 +242,6 @@ class TestLongTextProcessStageProcess:
assert result.new_query is query assert result.new_query is query
stage.strategy_impl.process.assert_not_called() stage.strategy_impl.process.assert_not_called()
class TestForwardStrategy: class TestForwardStrategy:
"""Tests for ForwardComponentStrategy.""" """Tests for ForwardComponentStrategy."""
@@ -256,7 +260,7 @@ class TestForwardStrategy:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
# Create a mock adapter with bot_account_id # Create a mock adapter with bot_account_id
mock_adapter = Mock() mock_adapter = Mock()
@@ -264,8 +268,10 @@ class TestForwardStrategy:
query.adapter = mock_adapter query.adapter = mock_adapter
# Long text exceeding threshold # Long text exceeding threshold
long_text = 'This is a very long response that exceeds the threshold' long_text = "This is a very long response that exceeds the threshold"
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])] query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=long_text)])
]
result = await stage.process(query, 'LongTextProcessStage') result = await stage.process(query, 'LongTextProcessStage')
@@ -291,13 +297,13 @@ class TestForwardStrategy:
await strat.initialize() await strat.initialize()
query = text_query('hello') query = text_query("hello")
query.pipeline_config = make_longtext_config() query.pipeline_config = make_longtext_config()
mock_adapter = Mock() mock_adapter = Mock()
mock_adapter.bot_account_id = '12345' mock_adapter.bot_account_id = '12345'
query.adapter = mock_adapter query.adapter = mock_adapter
components = await strat.process('test message', query) components = await strat.process("test message", query)
assert len(components) == 1 assert len(components) == 1
assert isinstance(components[0], platform_message.Forward) assert isinstance(components[0], platform_message.Forward)
@@ -320,12 +326,14 @@ class TestLongTextThreshold:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
# Text below threshold # Text below threshold
short_text = 'x' * (threshold - 1) short_text = "x" * (threshold - 1)
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])] query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=short_text)])
]
result = await stage.process(query, 'LongTextProcessStage') result = await stage.process(query, 'LongTextProcessStage')

View File

@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
# Create query with 3 messages (within limit) # Create query with 3 messages (within limit)
query = text_query('current message') query = text_query("current message")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [ query.messages = [
provider_message.Message(role='user', content='message 1'), provider_message.Message(role='user', content='message 1'),
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
# Create query with many messages exceeding limit # Create query with many messages exceeding limit
# 7 messages = 3 full rounds + 1 current user # 7 messages = 3 full rounds + 1 current user
query = text_query('current message') query = text_query("current message")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [ query.messages = [
provider_message.Message(role='user', content='message 1'), provider_message.Message(role='user', content='message 1'),
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [] query.messages = []
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [ query.messages = [
provider_message.Message(role='user', content='hello'), provider_message.Message(role='user', content='hello'),
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('current') query = text_query("current")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [ query.messages = [
provider_message.Message(role='user', content='user1'), provider_message.Message(role='user', content='user1'),
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('current') query = text_query("current")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.messages = [ query.messages = [
provider_message.Message(role='user', content='old1'), provider_message.Message(role='user', content='old1'),
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
trun = trun_cls(app) trun = trun_cls(app)
break break
query = text_query('hello') query = text_query("hello")
query.pipeline_config = make_truncate_config(max_round=3) query.pipeline_config = make_truncate_config(max_round=3)
query.messages = [ query.messages = [
provider_message.Message(role='user', content='m1'), provider_message.Message(role='user', content='m1'),

View File

@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query('hello world') query = text_query("hello world")
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query('test message') query = text_query("test message")
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -194,16 +194,13 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
# Image query with base64 # Image query with base64
query = image_query(text='look at this', url=None) query = image_query(text="look at this", url=None)
# Set base64 on the image component # Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([
chain = platform_message.MessageChain( platform_message.Plain(text="look at this"),
[ platform_message.Image(base64="data:image/png;base64,abc123"),
platform_message.Plain(text='look at this'), ])
platform_message.Image(base64='data:image/png;base64,abc123'),
]
)
query.message_chain = chain query.message_chain = chain
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -241,7 +238,7 @@ class TestPreProcessorImageSegment:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = image_query(text='describe this') query = image_query(text="describe this")
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -279,7 +276,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query('hello') query = text_query("hello")
# Set pipeline config with primary model # Set pipeline config with primary model
query.pipeline_config = { query.pipeline_config = {
@@ -338,7 +335,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = { query.pipeline_config = {
'ai': { 'ai': {
@@ -387,7 +384,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = text_query('hello', sender_id=67890) query = text_query("hello", sender_id=67890)
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')
@@ -424,7 +421,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app) stage = preproc.PreProcessor(app)
query = group_text_query('hello', group_id=99999) query = group_text_query("hello", group_id=99999)
result = await stage.process(query, 'PreProcessor') result = await stage.process(query, 'PreProcessor')

View File

@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
'safety': { 'safety': {
'rate-limit': { 'rate-limit': {
'window-length': 60, # 60 seconds window 'window-length': 60, # 60 seconds window
'limitation': 10, # 10 requests per window 'limitation': 10, # 10 requests per window
'strategy': 'drop', 'strategy': 'drop',
} }
} }
@@ -75,9 +75,11 @@ class TestFixedWindowAlgo:
# Make requests within limit # Make requests within limit
for i in range(10): for i in range(10):
result = await algo.require_access( result = await algo.require_access(
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345' sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
) )
assert result is True, f'Request {i + 1} should be allowed' assert result is True, f"Request {i+1} should be allowed"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit): async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -89,12 +91,20 @@ class TestFixedWindowAlgo:
# Exhaust the limit # Exhaust the limit
for i in range(10): for i in range(10):
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
# Next request should be denied # Next request should be denied
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
assert result is False, 'Request exceeding limit should be denied' assert result is False, "Request exceeding limit should be denied"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit): async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -106,14 +116,20 @@ class TestFixedWindowAlgo:
# Exhaust limit for session 1 # Exhaust limit for session 1
for i in range(10): for i in range(10):
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1') await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session1'
)
# Session 2 should still have its own limit # Session 2 should still have its own limit
result = await algo.require_access( result = await algo.require_access(
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2' sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session2'
) )
assert result is True, 'Different session should have independent limit' assert result is True, "Different session should have independent limit"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query): async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
@@ -134,11 +150,19 @@ class TestFixedWindowAlgo:
await algo.initialize() await algo.initialize()
# First request allowed # First request allowed
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345') result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
assert result1 is True assert result1 is True
# Second request denied # Second request denied
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345') result2 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
assert result2 is False assert result2 is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -150,7 +174,11 @@ class TestFixedWindowAlgo:
await algo.initialize() await algo.initialize()
# First request creates container # First request creates container
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation) # Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
expected_key = 'LauncherTypes.PERSON_12345' expected_key = 'LauncherTypes.PERSON_12345'
@@ -202,7 +230,7 @@ class TestFixedWindowAlgo:
# New request should be allowed (new window) # New request should be allowed (new window)
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test') result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
assert result is True, 'New window should allow new requests' assert result is True, "New window should allow new requests"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query): async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
@@ -228,21 +256,29 @@ class TestFixedWindowAlgo:
# First request allowed # First request allowed
start_time = time.time() start_time = time.time()
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
assert result1 is True assert result1 is True
# Exhaust limit # Exhaust limit
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
# Third request should wait and then succeed # Third request should wait and then succeed
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') result3 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
elapsed = time.time() - start_time elapsed = time.time() - start_time
assert result3 is True, 'After wait, request should succeed' assert result3 is True, "After wait, request should succeed"
# Should have waited approximately until next window # Should have waited approximately until next window
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance) # With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
# Note: This is a timing-sensitive test, so we use a generous tolerance # Note: This is a timing-sensitive test, so we use a generous tolerance
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s' assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit): async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -253,7 +289,11 @@ class TestFixedWindowAlgo:
await algo.initialize() await algo.initialize()
# release_access is empty in current implementation # release_access is empty in current implementation
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') await algo.release_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
# Should not raise or change state # Should not raise or change state
assert 'person_12345' not in algo.containers assert 'person_12345' not in algo.containers

View File

@@ -55,7 +55,7 @@ def make_session():
launcher_type=provider_session.LauncherTypes.PERSON, launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345, launcher_id=12345,
sender_id=12345, sender_id=12345,
use_prompt_name='default', use_prompt_name="default",
using_conversation=None, using_conversation=None,
conversations=[], conversations=[],
) )
@@ -93,9 +93,11 @@ class TestResponseWrapperMessageChain:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])] query.resp_messages = [
platform_message.MessageChain([platform_message.Plain(text="response")])
]
query.resp_message_chain = [] query.resp_message_chain = []
results = [] results = []
@@ -123,7 +125,7 @@ class TestResponseWrapperCommand:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
@@ -131,7 +133,7 @@ class TestResponseWrapperCommand:
command_resp = Mock() command_resp = Mock()
command_resp.role = 'command' command_resp.role = 'command'
command_resp.get_content_platform_message_chain = Mock( command_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
) )
query.resp_messages = [command_resp] query.resp_messages = [command_resp]
@@ -161,7 +163,7 @@ class TestResponseWrapperPlugin:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
@@ -169,7 +171,7 @@ class TestResponseWrapperPlugin:
plugin_resp = Mock() plugin_resp = Mock()
plugin_resp.role = 'plugin' plugin_resp.role = 'plugin'
plugin_resp.get_content_platform_message_chain = Mock( plugin_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
) )
query.resp_messages = [plugin_resp] query.resp_messages = [plugin_resp]
@@ -209,17 +211,17 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
# Create assistant response with content # Create assistant response with content
assistant_resp = Mock() assistant_resp = Mock()
assistant_resp.role = 'assistant' assistant_resp.role = 'assistant'
assistant_resp.content = 'Hello back!' assistant_resp.content = "Hello back!"
assistant_resp.tool_calls = None assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock( assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
) )
query.resp_messages = [assistant_resp] query.resp_messages = [assistant_resp]
@@ -245,7 +247,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
@@ -290,7 +292,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
@@ -301,10 +303,10 @@ class TestResponseWrapperAssistant:
assistant_resp = Mock() assistant_resp = Mock()
assistant_resp.role = 'assistant' assistant_resp.role = 'assistant'
assistant_resp.content = 'Processing...' assistant_resp.content = "Processing..."
assistant_resp.tool_calls = [mock_tool_call] assistant_resp.tool_calls = [mock_tool_call]
assistant_resp.get_content_platform_message_chain = Mock( assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
) )
query.resp_messages = [assistant_resp] query.resp_messages = [assistant_resp]
@@ -344,17 +346,17 @@ class TestResponseWrapperInterrupt:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
# Create assistant response with content # Create assistant response with content
assistant_resp = Mock() assistant_resp = Mock()
assistant_resp.role = 'assistant' assistant_resp.role = 'assistant'
assistant_resp.content = 'Hello!' assistant_resp.content = "Hello!"
assistant_resp.tool_calls = None assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock( assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
) )
query.resp_messages = [assistant_resp] query.resp_messages = [assistant_resp]
@@ -382,7 +384,7 @@ class TestResponseWrapperCustomReply:
app.sess_mgr.get_session = AsyncMock(return_value=session) app.sess_mgr.get_session = AsyncMock(return_value=session)
# Mock plugin connector with custom reply # Mock plugin connector with custom reply
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')]) custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
mock_event_ctx = Mock() mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False) mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock() mock_event_ctx.event = Mock()
@@ -395,17 +397,17 @@ class TestResponseWrapperCustomReply:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
# Create assistant response # Create assistant response
assistant_resp = Mock() assistant_resp = Mock()
assistant_resp.role = 'assistant' assistant_resp.role = 'assistant'
assistant_resp.content = 'Default reply' assistant_resp.content = "Default reply"
assistant_resp.tool_calls = None assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock( assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
) )
query.resp_messages = [assistant_resp] query.resp_messages = [assistant_resp]
@@ -419,7 +421,7 @@ class TestResponseWrapperCustomReply:
assert len(results[0].new_query.resp_message_chain) == 1 assert len(results[0].new_query.resp_message_chain) == 1
# Should be the custom chain # Should be the custom chain
chain = results[0].new_query.resp_message_chain[0] chain = results[0].new_query.resp_message_chain[0]
assert 'Custom reply' in str(chain) assert "Custom reply" in str(chain)
class TestResponseWrapperVariables: class TestResponseWrapperVariables:
@@ -450,7 +452,7 @@ class TestResponseWrapperVariables:
await stage.initialize(pipeline_config) await stage.initialize(pipeline_config)
query = text_query('hello') query = text_query("hello")
query.pipeline_config = pipeline_config query.pipeline_config = pipeline_config
query.resp_message_chain = [] query.resp_message_chain = []
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2'] query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
@@ -458,10 +460,10 @@ class TestResponseWrapperVariables:
# Create assistant response # Create assistant response
assistant_resp = Mock() assistant_resp = Mock()
assistant_resp.role = 'assistant' assistant_resp.role = 'assistant'
assistant_resp.content = 'Hello' assistant_resp.content = "Hello"
assistant_resp.tool_calls = None assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock( assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')]) return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
) )
query.resp_messages = [assistant_resp] query.resp_messages = [assistant_resp]

View File

@@ -6,7 +6,6 @@ Tests cover:
- RAG methods (ingest, retrieve, schema) - RAG methods (ingest, retrieve, schema)
- Disabled plugin early returns - Disabled plugin early returns
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -87,12 +86,16 @@ class TestListPlugins:
return_value=[ return_value=[
{ {
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}}, 'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}], 'components': [
{'manifest': {'manifest': {'kind': 'Command'}}}
],
'debug': False, 'debug': False,
}, },
{ {
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}}, 'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}], 'components': [
{'manifest': {'manifest': {'kind': 'Tool'}}}
],
'debug': False, 'debug': False,
}, },
] ]
@@ -124,7 +127,9 @@ class TestListPlugins:
}, },
] ]
) )
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([]))) connector.ap.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(__iter__=lambda self: iter([]))
)
result = await connector.list_plugins() result = await connector.list_plugins()
@@ -225,8 +230,7 @@ class TestCallParser:
) )
connector.handler.parse_document.assert_called_once_with( connector.handler.parse_document.assert_called_once_with(
'author', 'author', 'parser',
'parser',
{'mime_type': 'text/plain', 'filename': 'test.txt'}, {'mime_type': 'text/plain', 'filename': 'test.txt'},
b'file content', b'file content',
) )
@@ -247,7 +251,9 @@ class TestRAGMethods:
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'}) result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'}) connector.handler.rag_ingest_document.assert_called_once_with(
'author', 'engine', {'file': 'test.pdf'}
)
assert result['status'] == 'success' assert result['status'] == 'success'
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -258,16 +264,14 @@ class TestRAGMethods:
connector.handler = AsyncMock() connector.handler = AsyncMock()
connector.handler.retrieve_knowledge = AsyncMock( connector.handler.retrieve_knowledge = AsyncMock(
return_value={ return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
'results': [
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
]
}
) )
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'}) result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'}) connector.handler.retrieve_knowledge.assert_called_once_with(
'author', 'engine', '', {'query': 'test'}
)
assert result == { assert result == {
'results': [ 'results': [
{ {
@@ -286,7 +290,9 @@ class TestRAGMethods:
connector = create_mock_connector() connector = create_mock_connector()
connector.handler = AsyncMock() connector.handler = AsyncMock()
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}}) connector.handler.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
result = await connector.get_rag_creation_schema('author/engine') result = await connector.get_rag_creation_schema('author/engine')
@@ -320,7 +326,9 @@ class TestRAGMethods:
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'}) await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'}) connector.handler.rag_on_kb_create.assert_called_once_with(
'author', 'engine', 'kb-uuid', {'model': 'test'}
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rag_on_kb_delete(self): async def test_rag_on_kb_delete(self):
@@ -346,7 +354,9 @@ class TestRAGMethods:
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid') result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid') connector.handler.rag_delete_document.assert_called_once_with(
'author', 'engine', 'doc-uuid', 'kb-uuid'
)
assert result is True assert result is True
@@ -436,7 +446,9 @@ class TestGetPluginInfo:
connector = create_mock_connector() connector = create_mock_connector()
connector.handler = AsyncMock() connector.handler = AsyncMock()
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}}) connector.handler.get_plugin_info = AsyncMock(
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
)
result = await connector.get_plugin_info('author', 'plugin') result = await connector.get_plugin_info('author', 'plugin')
@@ -458,7 +470,9 @@ class TestSetPluginConfig:
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'}) await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'}) connector.handler.set_plugin_config.assert_called_once_with(
'author', 'plugin', {'setting': 'value'}
)
class TestPingPluginRuntime: class TestPingPluginRuntime:

View File

@@ -3,7 +3,6 @@
Tests cover: Tests cover:
- _parse_plugin_id() parsing and validation - _parse_plugin_id() parsing and validation
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest

View File

@@ -6,7 +6,6 @@ Tests cover:
- Handling missing requirements.txt - Handling missing requirements.txt
- Handling empty/malformed requirements.txt - Handling empty/malformed requirements.txt
""" """
from __future__ import annotations from __future__ import annotations
import zipfile import zipfile
@@ -83,13 +82,13 @@ class TestExtractDepsMetadata:
"""Test that comments and empty lines are filtered.""" """Test that comments and empty lines are filtered."""
connector_instance = create_mock_connector() connector_instance = create_mock_connector()
requirements = """# This is a comment requirements = '''# This is a comment
requests>=2.0 requests>=2.0
# Another comment # Another comment
flask==1.0 flask==1.0
numpy""" numpy'''
zip_bytes = create_zip_with_requirements(requirements) zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock() task_context = Mock()
@@ -148,9 +147,9 @@ numpy"""
"""Test handling requirements.txt with only comments.""" """Test handling requirements.txt with only comments."""
connector_instance = create_mock_connector() connector_instance = create_mock_connector()
requirements = """# Comment 1 requirements = '''# Comment 1
# Comment 2 # Comment 2
# Comment 3""" # Comment 3'''
zip_bytes = create_zip_with_requirements(requirements) zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock() task_context = Mock()

View File

@@ -40,13 +40,11 @@ class TestHandlerQueryVariables:
"""Test set_query_var returns error when query not found.""" """Test set_query_var returns error when query not found."""
runtime_handler = make_handler(mock_app) runtime_handler = make_handler(mock_app)
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]( response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
{ 'query_id': 'nonexistent-query',
'query_id': 'nonexistent-query', 'key': 'test_var',
'key': 'test_var', 'value': 'test_value',
'value': 'test_value', })
}
)
assert response.code != 0 assert response.code != 0
assert 'nonexistent-query' in response.message assert 'nonexistent-query' in response.message
@@ -60,13 +58,11 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]( response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
{ 'query_id': 'test-query',
'query_id': 'test-query', 'key': 'test_var',
'key': 'test_var', 'value': 'test_value',
'value': 'test_value', })
}
)
assert response.code == 0 assert response.code == 0
assert mock_query.variables['test_var'] == 'test_value' assert mock_query.variables['test_var'] == 'test_value'
@@ -80,12 +76,10 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]( response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
{ 'query_id': 'test-query',
'query_id': 'test-query', 'key': 'existing_var',
'key': 'existing_var', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == {'value': 'existing_value'} assert response.data == {'value': 'existing_value'}
@@ -99,11 +93,9 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]( response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
{ 'query_id': 'test-query',
'query_id': 'test-query', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == {'vars': mock_query.variables} assert response.data == {'vars': mock_query.variables}
@@ -116,7 +108,7 @@ class TestHandlerRagErrorResponse:
"""Test basic error response creation.""" """Test basic error response creation."""
from langbot.pkg.plugin.handler import _make_rag_error_response from langbot.pkg.plugin.handler import _make_rag_error_response
error = Exception('test error') error = Exception("test error")
response = _make_rag_error_response(error, 'TestError') response = _make_rag_error_response(error, 'TestError')
# ActionResponse is a pydantic model, check message field # ActionResponse is a pydantic model, check message field
@@ -128,8 +120,13 @@ class TestHandlerRagErrorResponse:
"""Test error response with extra context.""" """Test error response with extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response from langbot.pkg.plugin.handler import _make_rag_error_response
error = ValueError('invalid input') error = ValueError("invalid input")
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test') response = _make_rag_error_response(
error,
'ValidationError',
field='name',
value='test'
)
assert 'ValidationError' in response.message assert 'ValidationError' in response.message
assert 'field=name' in response.message assert 'field=name' in response.message
@@ -140,7 +137,7 @@ class TestHandlerRagErrorResponse:
"""Test error response includes exception type.""" """Test error response includes exception type."""
from langbot.pkg.plugin.handler import _make_rag_error_response from langbot.pkg.plugin.handler import _make_rag_error_response
error = RuntimeError('connection failed') error = RuntimeError("connection failed")
response = _make_rag_error_response(error, 'ConnectionError') response = _make_rag_error_response(error, 'ConnectionError')
assert 'RuntimeError' in response.message assert 'RuntimeError' in response.message
@@ -151,7 +148,7 @@ class TestHandlerRagErrorResponse:
"""Test error response with no extra context.""" """Test error response with no extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response from langbot.pkg.plugin.handler import _make_rag_error_response
error = KeyError('missing_key') error = KeyError("missing_key")
response = _make_rag_error_response(error, 'LookupError') response = _make_rag_error_response(error, 'LookupError')
# No context parts means no brackets # No context parts means no brackets

View File

@@ -47,14 +47,12 @@ class TestInitializePluginSettings:
Mock(), Mock(),
] ]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
{ 'plugin_author': 'test-author',
'plugin_author': 'test-author', 'plugin_name': 'test-plugin',
'plugin_name': 'test-plugin', 'install_source': 'local',
'install_source': 'local', 'install_info': {'path': '/test'},
'install_info': {'path': '/test'}, })
}
)
assert response.code == 0 assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2 assert app.persistence_mgr.execute_async.await_count == 2
@@ -84,14 +82,12 @@ class TestInitializePluginSettings:
Mock(), Mock(),
] ]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
{ 'plugin_author': 'test-author',
'plugin_author': 'test-author', 'plugin_name': 'test-plugin',
'plugin_name': 'test-plugin', 'install_source': 'github',
'install_source': 'github', 'install_info': {'repo': 'author/name'},
'install_info': {'repo': 'author/name'}, })
}
)
assert response.code == 0 assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 3 assert app.persistence_mgr.execute_async.await_count == 3
@@ -165,7 +161,9 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old')) app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new')) response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'new')
)
assert response.code == 0 assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2 assert app.persistence_mgr.execute_async.await_count == 2
@@ -205,7 +203,9 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0 app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x')) response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'x')
)
assert response.code != 0 assert response.code != 0
assert '1 > 0 bytes' in response.message assert '1 > 0 bytes' in response.message
@@ -228,12 +228,10 @@ class TestGetPluginSettings:
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result() app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
{ 'plugin_author': 'test-author',
'plugin_author': 'test-author', 'plugin_name': 'test-plugin',
'plugin_name': 'test-plugin', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == { assert response.data == {
@@ -257,12 +255,10 @@ class TestGetPluginSettings:
) )
app.persistence_mgr.execute_async.return_value = make_result(setting) app.persistence_mgr.execute_async.return_value = make_result(setting)
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
{ 'plugin_author': 'test-author',
'plugin_author': 'test-author', 'plugin_name': 'test-plugin',
'plugin_name': 'test-plugin', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == { assert response.data == {
@@ -290,13 +286,11 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content')) app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
{ 'key': 'test-key',
'key': 'test-key', 'owner_type': 'plugin',
'owner_type': 'plugin', 'owner': 'test-owner',
'owner': 'test-owner', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == { assert response.data == {
@@ -309,13 +303,11 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result() app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]( response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
{ 'key': 'test-key',
'key': 'test-key', 'owner_type': 'plugin',
'owner_type': 'plugin', 'owner': 'test-owner',
'owner': 'test-owner', })
}
)
assert response.code != 0 assert response.code != 0
assert 'Storage with key test-key not found' in response.message assert 'Storage with key test-key not found' in response.message
@@ -337,11 +329,9 @@ class TestHandlerQueryLookup:
"""Query-bound actions return error when query_id is not cached.""" """Query-bound actions return error when query_id is not cached."""
runtime_handler = make_handler(app) runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]( response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
{ 'query_id': 'nonexistent-query',
'query_id': 'nonexistent-query', })
}
)
assert response.code != 0 assert response.code != 0
assert 'nonexistent-query' in response.message assert 'nonexistent-query' in response.message
@@ -353,11 +343,9 @@ class TestHandlerQueryLookup:
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid') query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
app.query_pool.cached_queries['existing-query'] = query app.query_pool.cached_queries['existing-query'] = query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]( response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
{ 'query_id': 'existing-query',
'query_id': 'existing-query', })
}
)
assert response.code == 0 assert response.code == 0
assert response.data == {'bot_uuid': 'test-bot-uuid'} assert response.data == {'bot_uuid': 'test-bot-uuid'}

View File

@@ -4,7 +4,6 @@ Tests cover:
- _make_rag_error_response() helper function - _make_rag_error_response() helper function
- RuntimeConnectionHandler cleanup_plugin_data method - RuntimeConnectionHandler cleanup_plugin_data method
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -24,7 +23,7 @@ class TestMakeRagErrorResponse:
"""Test basic error response creation.""" """Test basic error response creation."""
handler = get_handler_module() handler = get_handler_module()
error = ValueError('test error message') error = ValueError("test error message")
result = handler._make_rag_error_response(error, 'TestError') result = handler._make_rag_error_response(error, 'TestError')
# ActionResponse.error() returns code=1 (error status) # ActionResponse.error() returns code=1 (error status)
@@ -37,7 +36,7 @@ class TestMakeRagErrorResponse:
"""Test that error type is included in message.""" """Test that error type is included in message."""
handler = get_handler_module() handler = get_handler_module()
error = RuntimeError('something went wrong') error = RuntimeError("something went wrong")
result = handler._make_rag_error_response(error, 'VectorStoreError') result = handler._make_rag_error_response(error, 'VectorStoreError')
assert '[VectorStoreError/RuntimeError]' in result.message assert '[VectorStoreError/RuntimeError]' in result.message
@@ -46,7 +45,7 @@ class TestMakeRagErrorResponse:
"""Test that extra context fields are included.""" """Test that extra context fields are included."""
handler = get_handler_module() handler = get_handler_module()
error = Exception('embedding failed') error = Exception("embedding failed")
result = handler._make_rag_error_response( result = handler._make_rag_error_response(
error, error,
'EmbeddingError', 'EmbeddingError',
@@ -72,7 +71,7 @@ class TestMakeRagErrorResponse:
"""Test multiple context fields are comma separated.""" """Test multiple context fields are comma separated."""
handler = get_handler_module() handler = get_handler_module()
error = IOError('file not found') error = IOError("file not found")
result = handler._make_rag_error_response( result = handler._make_rag_error_response(
error, error,
'FileServiceError', 'FileServiceError',
@@ -120,7 +119,9 @@ class TestCleanupPluginData:
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler) handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
handler_instance.ap = mock_app handler_instance.ap = mock_app
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name') await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
handler_instance, 'author', 'plugin-name'
)
# Should have at least 2 calls: one for settings, one for binary storage # Should have at least 2 calls: one for settings, one for binary storage
assert mock_app.persistence_mgr.execute_async.call_count >= 2 assert mock_app.persistence_mgr.execute_async.call_count >= 2

View File

@@ -88,10 +88,7 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False): async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
return provider_message.Message(
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
)
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}): async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results.""" """Return fake rerank results."""
@@ -138,10 +135,8 @@ def mock_app_for_modelmgr():
# Fake persistence manager - returns empty results by default # Fake persistence manager - returns empty results by default
app.persistence_mgr = SimpleNamespace() app.persistence_mgr = SimpleNamespace()
async def default_execute(query): async def default_execute(query):
return _make_mock_result([]) return _make_mock_result([])
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute) app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
# Fake discover engine # Fake discover engine
@@ -170,7 +165,9 @@ def fake_requester_registry(mock_app_for_modelmgr):
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester) fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester) another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component]) app.discover.get_components_by_kind = Mock(
return_value=[fake_component, another_component]
)
model_mgr = ModelManager(app) model_mgr = ModelManager(app)
return model_mgr return model_mgr

View File

@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
'base-url': 'https://api.dify.ai', 'base-url': 'https://api.dify.ai',
} }
}, },
'output': {'misc': {}}, 'output': {'misc': {}}
} }
runner = DifyServiceAPIRunner(mock_app, pipeline_config) runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai', 'base-url': 'https://api.dify.ai',
} }
}, },
'output': {'misc': {}}, 'output': {'misc': {}}
} }
with pytest.raises(DifyAPIError, match='不支持'): with pytest.raises(DifyAPIError, match='不支持'):
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai', 'base-url': 'https://api.dify.ai',
} }
}, },
'output': {'misc': {}}, 'output': {'misc': {}}
} }
runner = DifyServiceAPIRunner(mock_app, pipeline_config) runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
'base-url': 'https://api.dify.ai', 'base-url': 'https://api.dify.ai',
} }
}, },
'output': {'misc': {}}, 'output': {'misc': {}}
} }
runner = DifyServiceAPIRunner(mock_app, pipeline_config) runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app assert runner.ap == mock_app

View File

@@ -1062,7 +1062,9 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info: with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
mock_get_model_info.side_effect = ( mock_get_model_info.side_effect = (
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {} lambda model: {'max_input_tokens': 131072}
if model == 'moonshot/moonshot-v1-128k'
else {}
) )
assert requester._safe_context_length('moonshot-v1-128k') == 131072 assert requester._safe_context_length('moonshot-v1-128k') == 131072

View File

@@ -635,9 +635,7 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider( async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel.""" """Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
model_mgr = fake_requester_registry model_mgr = fake_requester_registry
@@ -650,9 +648,7 @@ async def test_model_manager_load_llm_model_with_provider(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider_from_row( async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider handles Row objects.""" """Test ModelManager.load_llm_model_with_provider handles Row objects."""
model_mgr = fake_requester_registry model_mgr = fake_requester_registry
@@ -665,9 +661,7 @@ async def test_model_manager_load_llm_model_with_provider_from_row(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_manager_load_embedding_model_with_provider( async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel.""" """Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
model_mgr = fake_requester_registry model_mgr = fake_requester_registry

View File

@@ -43,7 +43,6 @@ class TestableRequester(requester.ProviderAPIRequester):
remove_think=False, remove_think=False,
): ):
import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message( return provider_message.Message(
role='assistant', role='assistant',
content=[provider_message.ContentElement(type='text', text='Testable response')], content=[provider_message.ContentElement(type='text', text='Testable response')],
@@ -290,9 +289,7 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
current_stage_name=None, current_stage_name=None,
) )
messages = [ messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
result = await provider.invoke_llm(query, runtime_llm_model, messages) result = await provider.invoke_llm(query, runtime_llm_model, messages)
@@ -333,9 +330,7 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
current_stage_name=None, current_stage_name=None,
) )
messages = [ messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
chunks = [] chunks = []
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages): async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
@@ -581,9 +576,7 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
current_stage_name=None, current_stage_name=None,
) )
messages = [ messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
with pytest.raises(RequesterError): with pytest.raises(RequesterError):
await provider.invoke_llm(query, model, messages) await provider.invoke_llm(query, model, messages)

View File

@@ -5,7 +5,6 @@ Tests cover:
- Conversation creation with prompts - Conversation creation with prompts
- Session concurrency semaphore - Session concurrency semaphore
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -61,7 +60,11 @@ class TestSessionManagerGetSession:
"""Create mock app with instance config.""" """Create mock app with instance config."""
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = {'concurrency': {'session': 5}} mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
return mock_app return mock_app
@pytest.fixture @pytest.fixture
@@ -170,7 +173,11 @@ class TestSessionManagerGetConversation:
"""Create mock app with instance config.""" """Create mock app with instance config."""
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = {'concurrency': {'session': 5}} mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
return mock_app return mock_app
@pytest.fixture @pytest.fixture
@@ -194,13 +201,17 @@ class TestSessionManagerGetConversation:
return query return query
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session): async def test_creates_conversation_with_prompt(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that get_conversation creates conversation with prompt.""" """Test that get_conversation creates conversation with prompt."""
sessionmgr = get_session_module() sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config) manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
pipeline_uuid = 'pipeline-123' pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123' bot_uuid = 'bot-123'
@@ -223,15 +234,21 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config) manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
pipeline_uuid = 'pipeline-123' pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123' bot_uuid = 'bot-123'
# First call creates conversation # First call creates conversation
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid) conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
# Second call with same pipeline should return same conversation # Second call with same pipeline should return same conversation
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid) conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
assert conv1 is conv2 assert conv1 is conv2
assert len(sample_session.conversations) == 1 assert len(sample_session.conversations) == 1
@@ -245,26 +262,36 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config) manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
# First call with pipeline1 # First call with pipeline1
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1') conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
)
# Second call with different pipeline should create new conversation # Second call with different pipeline should create new conversation
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2') conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
)
assert conv1 is not conv2 assert conv1 is not conv2
assert len(sample_session.conversations) == 2 assert len(sample_session.conversations) == 2
assert sample_session.using_conversation is conv2 assert sample_session.using_conversation is conv2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session): async def test_conversation_has_empty_messages(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that created conversation has empty messages list.""" """Test that created conversation has empty messages list."""
sessionmgr = get_session_module() sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config) manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
conversation = await manager.get_conversation( conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123' sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
@@ -273,17 +300,22 @@ class TestSessionManagerGetConversation:
assert conversation.messages == [] assert conversation.messages == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session): async def test_prompt_messages_from_config(
self, mock_app_with_config, sample_query, sample_session
):
"""Test that prompt messages are created from prompt_config.""" """Test that prompt messages are created from prompt_config."""
sessionmgr = get_session_module() sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config) manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}] prompt_config = [
{'role': 'system', 'content': 'System message'},
{'role': 'user', 'content': 'User message'}
]
conversation = await manager.get_conversation( conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123' sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
) )
assert conversation.prompt.name == 'default' assert conversation.prompt.name == 'default'
assert len(conversation.prompt.messages) == 2 assert len(conversation.prompt.messages) == 2

View File

@@ -136,7 +136,6 @@ class TestToolManagerSchemaGeneration:
assert 'description' in func assert 'description' in func
assert 'parameters' in func assert 'parameters' in func
class TestToolManagerExecuteFuncCall: class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method.""" """Tests for execute_func_call method."""

View File

@@ -3,7 +3,6 @@
Tests cover: Tests cover:
- _to_i18n_name() static method - _to_i18n_name() static method
""" """
from __future__ import annotations from __future__ import annotations
from importlib import import_module from importlib import import_module
@@ -61,4 +60,4 @@ class TestToI18nName:
kbmgr = get_kbmgr_module() kbmgr = get_kbmgr_module()
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'} input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
result = kbmgr.RAGManager._to_i18n_name(input_dict) result = kbmgr.RAGManager._to_i18n_name(input_dict)
assert result == {'en_US': 'English', 'extra_key': 'extra_value'} assert result == {'en_US': 'English', 'extra_key': 'extra_value'}

View File

@@ -6,7 +6,6 @@ Tests cover:
- Knowledge engine enrichment - Knowledge engine enrichment
- KB loading and removal - KB loading and removal
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -102,9 +101,13 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module() rag_module = get_rag_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}]) mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.persistence_mgr.execute_async = AsyncMock() mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin error')) mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin error')
)
manager = rag_module.RAGManager(mock_app) manager = rag_module.RAGManager(mock_app)
@@ -125,7 +128,9 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module() rag_module = get_rag_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}]) mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.persistence_mgr.execute_async = AsyncMock() mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock() mock_app.plugin_connector.rag_on_kb_create = AsyncMock()
@@ -201,7 +206,9 @@ class TestRuntimeKnowledgeBaseOnKBCreate:
mock_app = create_mock_app() mock_app = create_mock_app()
mock_kb = create_mock_kb_entity() mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin failed')) mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin failed')
)
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -238,7 +245,9 @@ class TestRuntimeKnowledgeBaseIngestDocument:
mock_app = create_mock_app() mock_app = create_mock_app()
mock_kb = create_mock_kb_entity() mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.call_rag_ingest = AsyncMock(return_value={'status': 'success'}) mock_app.plugin_connector.call_rag_ingest = AsyncMock(
return_value={'status': 'success'}
)
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -295,10 +304,14 @@ class TestRAGManagerLoadKnowledgeBasesFromDB:
# KB that will cause initialize to fail # KB that will cause initialize to fail
mock_kb = create_mock_kb_entity() mock_kb = create_mock_kb_entity()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb]))) mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb]))
)
# Make initialize fail by having plugin_connector throw error # Make initialize fail by having plugin_connector throw error
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Init failed')) mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Init failed')
)
manager = rag_module.RAGManager(mock_app) manager = rag_module.RAGManager(mock_app)
# Should not raise - errors are caught # Should not raise - errors are caught
@@ -398,7 +411,9 @@ class TestRuntimeKnowledgeBaseRetrieve:
mock_kb = create_mock_kb_entity() mock_kb = create_mock_kb_entity()
mock_kb.retrieval_settings = {} mock_kb.retrieval_settings = {}
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(return_value={'results': []}) mock_app.plugin_connector.call_rag_retrieve = AsyncMock(
return_value={'results': []}
)
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -667,7 +682,9 @@ class TestRAGManagerGetAllDetails:
"""Test returns empty list when no knowledge bases.""" """Test returns empty list when no knowledge bases."""
rag_module = get_rag_module() rag_module = get_rag_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[]))) mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[]))
)
manager = rag_module.RAGManager(mock_app) manager = rag_module.RAGManager(mock_app)
result = await manager.get_all_knowledge_base_details() result = await manager.get_all_knowledge_base_details()
@@ -682,7 +699,9 @@ class TestRAGManagerGetAllDetails:
# Mock DB result # Mock DB result
mock_kb_row = Mock() mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb_row]))) mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb_row]))
)
mock_app.persistence_mgr.serialize_model = Mock( mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'} return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
) )
@@ -705,7 +724,9 @@ class TestRAGManagerGetDetails:
"""Test returns None when KB doesn't exist.""" """Test returns None when KB doesn't exist."""
rag_module = get_rag_module() rag_module = get_rag_module()
mock_app = create_mock_app() mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=None))) mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=None))
)
manager = rag_module.RAGManager(mock_app) manager = rag_module.RAGManager(mock_app)
result = await manager.get_knowledge_base_details('nonexistent') result = await manager.get_knowledge_base_details('nonexistent')
@@ -719,7 +740,9 @@ class TestRAGManagerGetDetails:
mock_app = create_mock_app() mock_app = create_mock_app()
mock_kb_row = Mock() mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=mock_kb_row))) mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=mock_kb_row))
)
mock_app.persistence_mgr.serialize_model = Mock( mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'} return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
) )
@@ -768,4 +791,4 @@ class TestRAGManagerLoadKnowledgeBase:
await manager.load_knowledge_base(kb_dict) await manager.load_knowledge_base(kb_dict)
assert 'kb-uuid' in manager.knowledge_bases assert 'kb-uuid' in manager.knowledge_bases

View File

@@ -121,12 +121,10 @@ class TestRAGRuntimeServiceVectorSearch:
"""Create mock app.""" """Create mock app."""
mock_app = MagicMock() mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock() mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.search = AsyncMock( mock_app.vector_db_mgr.search = AsyncMock(return_value=[
return_value=[ {'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}}, {'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}}, ])
]
)
return mock_app return mock_app
def _make_rag_import_mocks(self): def _make_rag_import_mocks(self):
@@ -303,7 +301,10 @@ class TestRAGRuntimeServiceVectorList:
mock_app = MagicMock() mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock() mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.list_by_filter = AsyncMock( mock_app.vector_db_mgr.list_by_filter = AsyncMock(
return_value=([{'id': 'id1', 'metadata': {'file_id': 'abc'}}], 10) return_value=(
[{'id': 'id1', 'metadata': {'file_id': 'abc'}}],
10
)
) )
return mock_app return mock_app

View File

@@ -21,8 +21,8 @@ from langbot.pkg.storage.providers.localstorage import LocalStorageProvider
@pytest.fixture @pytest.fixture
def storage_provider(tmp_path): def storage_provider(tmp_path):
"""Create a LocalStorageProvider with a temporary storage path.""" """Create a LocalStorageProvider with a temporary storage path."""
storage_path = str(tmp_path / 'storage') storage_path = str(tmp_path / "storage")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
mock_app = Mock() mock_app = Mock()
provider = LocalStorageProvider(mock_app) provider = LocalStorageProvider(mock_app)
yield provider, storage_path yield provider, storage_path
@@ -35,15 +35,15 @@ class TestPathTraversalPrevention:
async def test_absolute_path_save_rejected(self, storage_provider, tmp_path): async def test_absolute_path_save_rejected(self, storage_provider, tmp_path):
"""Saving with an absolute path key must be blocked.""" """Saving with an absolute path key must be blocked."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
target_file = str(tmp_path / 'pwned.txt') target_file = str(tmp_path / "pwned.txt")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with pytest.raises((ValueError, PermissionError)): with pytest.raises((ValueError, PermissionError)):
await provider.save(target_file, b'malicious content') await provider.save(target_file, b"malicious content")
# The file must NOT exist outside the storage directory # The file must NOT exist outside the storage directory
assert not os.path.exists(target_file), ( assert not os.path.exists(target_file), (
f'Path traversal succeeded: file was written outside storage to {target_file}' f"Path traversal succeeded: file was written outside storage to {target_file}"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -52,28 +52,32 @@ class TestPathTraversalPrevention:
provider, storage_path = storage_provider provider, storage_path = storage_provider
# Create a file outside the storage directory # Create a file outside the storage directory
target_file = str(tmp_path / 'secret.txt') target_file = str(tmp_path / "secret.txt")
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(b'secret data') f.write(b"secret data")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)): with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(target_file) data = await provider.load(target_file)
assert data != b'secret data', 'Path traversal succeeded: read file outside storage' assert data != b"secret data", (
"Path traversal succeeded: read file outside storage"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path): async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path):
"""Exists check with an absolute path key must be blocked or return False.""" """Exists check with an absolute path key must be blocked or return False."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
target_file = str(tmp_path / 'check_me.txt') target_file = str(tmp_path / "check_me.txt")
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(b'data') f.write(b"data")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
try: try:
result = await provider.exists(target_file) result = await provider.exists(target_file)
assert result is False, 'Path traversal succeeded: exists() returned True for file outside storage' assert result is False, (
"Path traversal succeeded: exists() returned True for file outside storage"
)
except (ValueError, PermissionError): except (ValueError, PermissionError):
pass # Expected pass # Expected
@@ -82,26 +86,28 @@ class TestPathTraversalPrevention:
"""Deleting with an absolute path key must be blocked.""" """Deleting with an absolute path key must be blocked."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
target_file = str(tmp_path / 'do_not_delete.txt') target_file = str(tmp_path / "do_not_delete.txt")
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(b'important data') f.write(b"important data")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)): with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.delete(target_file) await provider.delete(target_file)
assert os.path.exists(target_file), 'Path traversal succeeded: file outside storage was deleted' assert os.path.exists(target_file), (
"Path traversal succeeded: file outside storage was deleted"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_absolute_path_size_rejected(self, storage_provider, tmp_path): async def test_absolute_path_size_rejected(self, storage_provider, tmp_path):
"""Size check with an absolute path key must be blocked.""" """Size check with an absolute path key must be blocked."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
target_file = str(tmp_path / 'measure_me.txt') target_file = str(tmp_path / "measure_me.txt")
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(b'some data') f.write(b"some data")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)): with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.size(target_file) await provider.size(target_file)
@@ -110,39 +116,41 @@ class TestPathTraversalPrevention:
"""Relative path traversal with '..' must be blocked.""" """Relative path traversal with '..' must be blocked."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
target_file = str(tmp_path / 'above_storage.txt') target_file = str(tmp_path / "above_storage.txt")
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(b'above storage secret') f.write(b"above storage secret")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
relative_key = os.path.join('..', 'above_storage.txt') relative_key = os.path.join("..", "above_storage.txt")
with pytest.raises((ValueError, PermissionError, FileNotFoundError)): with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(relative_key) data = await provider.load(relative_key)
assert data != b'above storage secret' assert data != b"above storage secret"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path): async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path):
"""delete_dir_recursive with traversal path must be blocked.""" """delete_dir_recursive with traversal path must be blocked."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
outside_dir = tmp_path / 'outside_dir' outside_dir = tmp_path / "outside_dir"
outside_dir.mkdir() outside_dir.mkdir()
(outside_dir / 'file.txt').write_text('important') (outside_dir / "file.txt").write_text("important")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with pytest.raises((ValueError, PermissionError)): with pytest.raises((ValueError, PermissionError)):
await provider.delete_dir_recursive(str(outside_dir)) await provider.delete_dir_recursive(str(outside_dir))
assert outside_dir.exists(), 'Path traversal succeeded: directory outside storage was deleted' assert outside_dir.exists(), (
"Path traversal succeeded: directory outside storage was deleted"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_legitimate_key_works(self, storage_provider): async def test_legitimate_key_works(self, storage_provider):
"""Normal keys without traversal must still work.""" """Normal keys without traversal must still work."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = 'test_image_abc123.png' key = "test_image_abc123.png"
content = b'PNG image data' content = b"PNG image data"
await provider.save(key, content) await provider.save(key, content)
assert await provider.exists(key) is True assert await provider.exists(key) is True
@@ -158,9 +166,9 @@ class TestPathTraversalPrevention:
"""Keys with legitimate subdirectories must still work.""" """Keys with legitimate subdirectories must still work."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = 'bot_log_images/img_001.png' key = "bot_log_images/img_001.png"
content = b'PNG image data' content = b"PNG image data"
await provider.save(key, content) await provider.save(key, content)
assert await provider.exists(key) is True assert await provider.exists(key) is True
@@ -173,33 +181,33 @@ class TestPathTraversalPrevention:
"""delete_dir_recursive should handle non-existing directories gracefully.""" """delete_dir_recursive should handle non-existing directories gracefully."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
# Try to delete a non-existing directory - should not raise # Try to delete a non-existing directory - should not raise
await provider.delete_dir_recursive('nonexistent_dir') await provider.delete_dir_recursive("nonexistent_dir")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_dir_recursive_with_files(self, storage_provider): async def test_delete_dir_recursive_with_files(self, storage_provider):
"""delete_dir_recursive should delete directory with files inside.""" """delete_dir_recursive should delete directory with files inside."""
provider, storage_path = storage_provider provider, storage_path = storage_provider
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
# Create a directory with files # Create a directory with files
key1 = 'test_dir/file1.txt' key1 = "test_dir/file1.txt"
key2 = 'test_dir/file2.txt' key2 = "test_dir/file2.txt"
await provider.save(key1, b'content1') await provider.save(key1, b"content1")
await provider.save(key2, b'content2') await provider.save(key2, b"content2")
# Verify files exist # Verify files exist
assert await provider.exists(key1) assert await provider.exists(key1)
assert await provider.exists(key2) assert await provider.exists(key2)
# Delete directory recursively # Delete directory recursively
await provider.delete_dir_recursive('test_dir') await provider.delete_dir_recursive("test_dir")
# Verify files no longer exist # Verify files no longer exist
assert not await provider.exists(key1) assert not await provider.exists(key1)
assert not await provider.exists(key2) assert not await provider.exists(key2)
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__, '-v']) pytest.main([__file__, "-v"])

View File

@@ -8,7 +8,6 @@ Tests cover:
Uses moto library to mock AWS S3 service. Uses moto library to mock AWS S3 service.
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -45,10 +44,8 @@ def mock_app_with_s3_config():
def s3_mock(): def s3_mock():
"""Set up moto S3 mock context.""" """Set up moto S3 mock context."""
from moto import mock_aws from moto import mock_aws
with mock_aws(): with mock_aws():
import boto3 import boto3
# Create bucket for tests that need pre-existing bucket # Create bucket for tests that need pre-existing bucket
s3 = boto3.client('s3', region_name='us-east-1') s3 = boto3.client('s3', region_name='us-east-1')
yield s3 yield s3
@@ -328,4 +325,4 @@ class TestS3StorageProviderErrorHandling:
await provider.initialize() await provider.initialize()
with pytest.raises(Exception): with pytest.raises(Exception):
await provider.size('nonexistent.txt') await provider.size('nonexistent.txt')

View File

@@ -31,7 +31,7 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
await storage_mgr.initialize() await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
mock_app.logger.info.assert_called() mock_app.logger.info.assert_called()
@@ -41,12 +41,12 @@ class TestStorageMgr:
"""Should use local storage when explicitly configured.""" """Should use local storage when explicitly configured."""
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = {'storage': {'use': 'local'}} mock_app.instance_config.data = {"storage": {"use": "local"}}
mock_app.logger = Mock() mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
await storage_mgr.initialize() await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -55,12 +55,14 @@ class TestStorageMgr:
"""Should use S3 storage when configured.""" """Should use S3 storage when configured."""
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = {'storage': {'use': 's3', 's3': {'endpoint_url': 'https://s3.amazonaws.com'}}} mock_app.instance_config.data = {
"storage": {"use": "s3", "s3": {"endpoint_url": "https://s3.amazonaws.com"}}
}
mock_app.logger = Mock() mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
with patch.object(S3StorageProvider, 'initialize', new_callable=AsyncMock): with patch.object(S3StorageProvider, "initialize", new_callable=AsyncMock):
await storage_mgr.initialize() await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, S3StorageProvider) assert isinstance(storage_mgr.storage_provider, S3StorageProvider)
@@ -69,12 +71,12 @@ class TestStorageMgr:
"""Should default to local storage for invalid storage type.""" """Should default to local storage for invalid storage type."""
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}} mock_app.instance_config.data = {"storage": {"use": "invalid_type"}}
mock_app.logger = Mock() mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
await storage_mgr.initialize() await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -88,7 +90,9 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock) as mock_init: with patch.object(
LocalStorageProvider, "initialize", new_callable=AsyncMock
) as mock_init:
await storage_mgr.initialize() await storage_mgr.initialize()
mock_init.assert_called_once() mock_init.assert_called_once()
@@ -101,8 +105,8 @@ class TestStorageProviderBase:
mock_app = Mock() mock_app = Mock()
# Use LocalStorageProvider as concrete implementation # Use LocalStorageProvider as concrete implementation
with patch('os.path.exists', return_value=True): with patch("os.path.exists", return_value=True):
with patch('os.makedirs'): with patch("os.makedirs"):
provider = LocalStorageProvider(mock_app) provider = LocalStorageProvider(mock_app)
assert provider.ap == mock_app assert provider.ap == mock_app
@@ -111,12 +115,12 @@ class TestStorageProviderBase:
"""Provider base initialize should be callable and do nothing.""" """Provider base initialize should be callable and do nothing."""
mock_app = Mock() mock_app = Mock()
with patch('os.path.exists', return_value=True): with patch("os.path.exists", return_value=True):
with patch('os.makedirs'): with patch("os.makedirs"):
provider = LocalStorageProvider(mock_app) provider = LocalStorageProvider(mock_app)
# Initialize should not raise # Initialize should not raise
await provider.initialize() await provider.initialize()
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__, '-v']) pytest.main([__file__, "-v"])

View File

@@ -8,7 +8,6 @@ Tests cover:
- HTTP request success/failure scenarios - HTTP request success/failure scenarios
- Source code bug: send_tasks should be instance variable - Source code bug: send_tasks should be instance variable
""" """
from __future__ import annotations from __future__ import annotations
import pytest import pytest
@@ -39,7 +38,6 @@ class TestTelemetryManagerInit:
manager = telemetry.TelemetryManager(mock_app) manager = telemetry.TelemetryManager(mock_app)
assert manager.telemetry_config == {} assert manager.telemetry_config == {}
class TestTelemetryManagerInitialize: class TestTelemetryManagerInitialize:
"""Tests for initialize() method.""" """Tests for initialize() method."""
@@ -220,7 +218,7 @@ class TestPayloadSanitization:
# All null string fields should be empty strings # All null string fields should be empty strings
for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']: for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']:
assert result[field] == '', f'Field {field} should be empty string, got {result[field]}' assert result[field] == '', f"Field {field} should be empty string, got {result[field]}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sanitize_string_fields_preserve_values(self): async def test_sanitize_string_fields_preserve_values(self):
@@ -420,7 +418,9 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'} manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock( mock_response = Mock(
status_code=200, text='{"code": 0, "msg": "success"}', json=Mock(return_value={'code': 0, 'msg': 'success'}) status_code=200,
text='{"code": 0, "msg": "success"}',
json=Mock(return_value={'code': 0, 'msg': 'success'})
) )
mock_client = Mock() mock_client = Mock()
@@ -448,7 +448,9 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'} manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock( mock_response = Mock(
status_code=500, text='Internal Server Error', json=Mock(return_value={'code': 500, 'msg': 'error'}) status_code=500,
text='Internal Server Error',
json=Mock(return_value={'code': 500, 'msg': 'error'})
) )
mock_client = Mock() mock_client = Mock()
@@ -476,7 +478,7 @@ class TestHTTPScenarios:
mock_response = Mock( mock_response = Mock(
status_code=200, status_code=200,
text='{"code": 400, "msg": "Bad Request"}', text='{"code": 400, "msg": "Bad Request"}',
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}), json=Mock(return_value={'code': 400, 'msg': 'Bad Request'})
) )
mock_client = Mock() mock_client = Mock()
@@ -491,7 +493,7 @@ class TestHTTPScenarios:
assert mock_app.logger.warning.call_count >= 1 assert mock_app.logger.warning.call_count >= 1
# Check that one of the calls contains application error info # Check that one of the calls contains application error info
all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list] all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list]
assert any('400' in w for w in all_warnings), f'No warning contained error code 400: {all_warnings}' assert any('400' in w for w in all_warnings), f"No warning contained error code 400: {all_warnings}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_timeout_logs_warning(self): async def test_send_timeout_logs_warning(self):

View File

@@ -9,7 +9,6 @@ Tests cover:
Note: Do NOT use 'from __future__ import annotations' because Note: Do NOT use 'from __future__ import annotations' because
funcschema.py expects actual type objects, not string annotations. funcschema.py expects actual type objects, not string annotations.
""" """
import pytest import pytest
from importlib import import_module from importlib import import_module

View File

@@ -20,53 +20,55 @@ class TestGetQQImageDownloadableUrl:
def test_basic_url(self): def test_basic_url(self):
"""Parse basic image URL.""" """Parse basic image URL."""
url = 'http://example.com/image.jpg' url = "http://example.com/image.jpg"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
assert result_url == 'http://example.com/image.jpg' assert result_url == "http://example.com/image.jpg"
assert query == {} assert query == {}
def test_url_with_query_params(self): def test_url_with_query_params(self):
"""Parse URL with query parameters.""" """Parse URL with query parameters."""
url = 'http://example.com/image.jpg?param1=value1&param2=value2' url = "http://example.com/image.jpg?param1=value1&param2=value2"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
assert result_url == 'http://example.com/image.jpg' assert result_url == "http://example.com/image.jpg"
assert query == {'param1': ['value1'], 'param2': ['value2']} assert query == {"param1": ["value1"], "param2": ["value2"]}
def test_url_with_port(self): def test_url_with_port(self):
"""Parse URL with port number.""" """Parse URL with port number."""
url = 'http://example.com:8080/image.jpg' url = "http://example.com:8080/image.jpg"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
assert result_url == 'http://example.com:8080/image.jpg' assert result_url == "http://example.com:8080/image.jpg"
def test_url_with_path(self): def test_url_with_path(self):
"""Parse URL with complex path.""" """Parse URL with complex path."""
url = 'http://example.com/path/to/image.jpg' url = "http://example.com/path/to/image.jpg"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
assert result_url == 'http://example.com/path/to/image.jpg' assert result_url == "http://example.com/path/to/image.jpg"
def test_url_with_fragment(self): def test_url_with_fragment(self):
"""Parse URL with fragment (fragment is not part of query).""" """Parse URL with fragment (fragment is not part of query)."""
url = 'http://example.com/image.jpg#fragment' url = "http://example.com/image.jpg#fragment"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
# Fragment is not included in query string parsing # Fragment is not included in query string parsing
assert 'http://example.com/image.jpg' in result_url assert "http://example.com/image.jpg" in result_url
def test_https_url(self): def test_https_url(self):
"""Parse HTTPS URL and preserve its scheme.""" """Parse HTTPS URL and preserve its scheme."""
url = 'https://example.com/image.jpg' url = "https://example.com/image.jpg"
result_url, query = get_qq_image_downloadable_url(url) result_url, query = get_qq_image_downloadable_url(url)
assert result_url == 'https://example.com/image.jpg' assert result_url == "https://example.com/image.jpg"
assert query == {} assert query == {}
def test_preserves_qq_https_scheme_and_query(self): def test_preserves_qq_https_scheme_and_query(self):
"""QQ image URLs keep HTTPS and query parameters.""" """QQ image URLs keep HTTPS and query parameters."""
result_url, query = get_qq_image_downloadable_url('https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1') result_url, query = get_qq_image_downloadable_url(
'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1'
)
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0' assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
assert query == {'term': ['2'], 'is_origin': ['1']} assert query == {'term': ['2'], 'is_origin': ['1']}
@@ -86,50 +88,50 @@ class TestExtractB64AndFormat:
async def test_jpeg_data_uri(self): async def test_jpeg_data_uri(self):
"""Extract base64 and format from JPEG data URI.""" """Extract base64 and format from JPEG data URI."""
# Create a simple base64 string # Create a simple base64 string
original_data = b'test image data' original_data = b"test image data"
b64_data = base64.b64encode(original_data).decode() b64_data = base64.b64encode(original_data).decode()
data_uri = f'data:image/jpeg;base64,{b64_data}' data_uri = f"data:image/jpeg;base64,{b64_data}"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data assert result_b64 == b64_data
assert result_format == 'jpeg' assert result_format == "jpeg"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_png_data_uri(self): async def test_png_data_uri(self):
"""Extract base64 and format from PNG data URI.""" """Extract base64 and format from PNG data URI."""
original_data = b'test png data' original_data = b"test png data"
b64_data = base64.b64encode(original_data).decode() b64_data = base64.b64encode(original_data).decode()
data_uri = f'data:image/png;base64,{b64_data}' data_uri = f"data:image/png;base64,{b64_data}"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data assert result_b64 == b64_data
assert result_format == 'png' assert result_format == "png"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gif_data_uri(self): async def test_gif_data_uri(self):
"""Extract base64 and format from GIF data URI.""" """Extract base64 and format from GIF data URI."""
original_data = b'test gif data' original_data = b"test gif data"
b64_data = base64.b64encode(original_data).decode() b64_data = base64.b64encode(original_data).decode()
data_uri = f'data:image/gif;base64,{b64_data}' data_uri = f"data:image/gif;base64,{b64_data}"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data assert result_b64 == b64_data
assert result_format == 'gif' assert result_format == "gif"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_webp_data_uri(self): async def test_webp_data_uri(self):
"""Extract base64 and format from WebP data URI.""" """Extract base64 and format from WebP data URI."""
original_data = b'test webp data' original_data = b"test webp data"
b64_data = base64.b64encode(original_data).decode() b64_data = base64.b64encode(original_data).decode()
data_uri = f'data:image/webp;base64,{b64_data}' data_uri = f"data:image/webp;base64,{b64_data}"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data assert result_b64 == b64_data
assert result_format == 'webp' assert result_format == "webp"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_complex_base64(self): async def test_complex_base64(self):
@@ -137,7 +139,7 @@ class TestExtractB64AndFormat:
# Base64 can include + and / characters # Base64 can include + and / characters
original_data = bytes(range(256)) # All byte values original_data = bytes(range(256)) # All byte values
b64_data = base64.b64encode(original_data).decode() b64_data = base64.b64encode(original_data).decode()
data_uri = f'data:image/png;base64,{b64_data}' data_uri = f"data:image/png;base64,{b64_data}"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
@@ -148,9 +150,9 @@ class TestExtractB64AndFormat:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_base64(self): async def test_empty_base64(self):
"""Handle empty base64 string.""" """Handle empty base64 string."""
data_uri = 'data:image/png;base64,' data_uri = "data:image/png;base64,"
result_b64, result_format = await extract_b64_and_format(data_uri) result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == '' assert result_b64 == ""
assert result_format == 'png' assert result_format == "png"

View File

@@ -23,52 +23,52 @@ class TestImportDir:
def test_calls_importlib_for_each_python_file(self, tmp_path): def test_calls_importlib_for_each_python_file(self, tmp_path):
"""Should call importlib.import_module for each .py file.""" """Should call importlib.import_module for each .py file."""
module_dir = tmp_path / 'test_modules' module_dir = tmp_path / "test_modules"
module_dir.mkdir() module_dir.mkdir()
(module_dir / '__init__.py').write_text('') (module_dir / "__init__.py").write_text("")
(module_dir / 'module_a.py').write_text("VALUE_A = 'a'\n") (module_dir / "module_a.py").write_text("VALUE_A = 'a'\n")
(module_dir / 'module_b.py').write_text("VALUE_B = 'b'\n") (module_dir / "module_b.py").write_text("VALUE_B = 'b'\n")
(module_dir / 'readme.txt').write_text('not a module') (module_dir / "readme.txt").write_text("not a module")
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with patch.object(importlib, 'import_module') as mock_import: with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.') importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
# Should call import_module for each .py file (excluding __init__.py) # Should call import_module for each .py file (excluding __init__.py)
assert mock_import.call_count == 2 assert mock_import.call_count == 2
def test_skips_init_py(self, tmp_path): def test_skips_init_py(self, tmp_path):
"""Should skip __init__.py when importing.""" """Should skip __init__.py when importing."""
module_dir = tmp_path / 'test_modules' module_dir = tmp_path / "test_modules"
module_dir.mkdir() module_dir.mkdir()
(module_dir / '__init__.py').write_text('') (module_dir / "__init__.py").write_text("")
(module_dir / 'regular.py').write_text('VALUE = 1\n') (module_dir / "regular.py").write_text("VALUE = 1\n")
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with patch.object(importlib, 'import_module') as mock_import: with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.') importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
# __init__.py should be skipped # __init__.py should be skipped
mock_import.assert_called_once() mock_import.assert_called_once()
# The call should not include __init__ # The call should not include __init__
call_args = mock_import.call_args[0][0] call_args = mock_import.call_args[0][0]
assert '__init__' not in call_args assert "__init__" not in call_args
def test_ignores_non_py_files(self, tmp_path): def test_ignores_non_py_files(self, tmp_path):
"""Should ignore non-.py files.""" """Should ignore non-.py files."""
module_dir = tmp_path / 'test_modules' module_dir = tmp_path / "test_modules"
module_dir.mkdir() module_dir.mkdir()
(module_dir / 'module.py').write_text('VALUE = 1\n') (module_dir / "module.py").write_text("VALUE = 1\n")
(module_dir / 'readme.txt').write_text('text') (module_dir / "readme.txt").write_text("text")
(module_dir / 'data.json').write_text('{}') (module_dir / "data.json").write_text("{}")
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with patch.object(importlib, 'import_module') as mock_import: with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.') importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
# Only .py files should be imported # Only .py files should be imported
assert mock_import.call_count == 1 assert mock_import.call_count == 1
@@ -79,14 +79,14 @@ class TestImportModulesInPkg:
def test_imports_modules_from_package(self, tmp_path): def test_imports_modules_from_package(self, tmp_path):
"""Should import all modules from a package object.""" """Should import all modules from a package object."""
mock_pkg = MagicMock() mock_pkg = MagicMock()
mock_pkg.__file__ = str(tmp_path / '__init__.py') mock_pkg.__file__ = str(tmp_path / "__init__.py")
(tmp_path / '__init__.py').write_text('') (tmp_path / "__init__.py").write_text("")
(tmp_path / 'mod1.py').write_text('MOD1 = 1\n') (tmp_path / "mod1.py").write_text("MOD1 = 1\n")
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with patch.object(importutil, 'import_dir') as mock_import_dir: with patch.object(importutil, "import_dir") as mock_import_dir:
importutil.import_modules_in_pkg(mock_pkg) importutil.import_modules_in_pkg(mock_pkg)
mock_import_dir.assert_called_once() mock_import_dir.assert_called_once()
call_path = mock_import_dir.call_args[0][0] call_path = mock_import_dir.call_args[0][0]
@@ -101,11 +101,11 @@ class TestImportModulesInPkgs:
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
mock_pkg1 = MagicMock() mock_pkg1 = MagicMock()
mock_pkg1.__file__ = '/path/to/pkg1/__init__.py' mock_pkg1.__file__ = "/path/to/pkg1/__init__.py"
mock_pkg2 = MagicMock() mock_pkg2 = MagicMock()
mock_pkg2.__file__ = '/path/to/pkg2/__init__.py' mock_pkg2.__file__ = "/path/to/pkg2/__init__.py"
with patch.object(importutil, 'import_modules_in_pkg') as mock_import: with patch.object(importutil, "import_modules_in_pkg") as mock_import:
importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2]) importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2])
assert mock_import.call_count == 2 assert mock_import.call_count == 2
@@ -116,18 +116,18 @@ class TestImportDotStyleDir:
def test_converts_dot_notation_to_path(self, tmp_path): def test_converts_dot_notation_to_path(self, tmp_path):
"""Should convert dot notation to path and import.""" """Should convert dot notation to path and import."""
# Create structure matching the dot notation # Create structure matching the dot notation
(tmp_path / 'my').mkdir() (tmp_path / "my").mkdir()
(tmp_path / 'my' / 'pkg').mkdir() (tmp_path / "my" / "pkg").mkdir()
(tmp_path / 'my' / 'pkg' / 'test').mkdir() (tmp_path / "my" / "pkg" / "test").mkdir()
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with patch.object(importutil, 'import_dir') as mock_import_dir: with patch.object(importutil, "import_dir") as mock_import_dir:
importutil.import_dot_style_dir('my.pkg.test') importutil.import_dot_style_dir("my.pkg.test")
# The path should be converted using os.path.join # The path should be converted using os.path.join
call_path = mock_import_dir.call_args[0][0] call_path = mock_import_dir.call_args[0][0]
# Should contain the path components joined # Should contain the path components joined
assert 'my' in call_path assert "my" in call_path
class TestReadResourceFile: class TestReadResourceFile:
@@ -137,16 +137,16 @@ class TestReadResourceFile:
"""Should read content from a resource file.""" """Should read content from a resource file."""
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
content = importutil.read_resource_file('templates/config.yaml') content = importutil.read_resource_file("templates/config.yaml")
assert 'admins:' in content assert "admins:" in content
assert 'edition: community' in content assert "edition: community" in content
def test_raises_for_nonexistent_file(self): def test_raises_for_nonexistent_file(self):
"""Should raise exception for non-existent resource file.""" """Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)): with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file('nonexistent/path/file.txt') importutil.read_resource_file("nonexistent/path/file.txt")
class TestReadResourceFileBytes: class TestReadResourceFileBytes:
@@ -156,16 +156,16 @@ class TestReadResourceFileBytes:
"""Should read content as bytes from a resource file.""" """Should read content as bytes from a resource file."""
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
content = importutil.read_resource_file_bytes('templates/config.yaml') content = importutil.read_resource_file_bytes("templates/config.yaml")
assert b'admins:' in content assert b"admins:" in content
assert b'edition: community' in content assert b"edition: community" in content
def test_raises_for_nonexistent_file_bytes(self): def test_raises_for_nonexistent_file_bytes(self):
"""Should raise exception for non-existent resource file.""" """Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)): with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file_bytes('nonexistent/path/file.txt') importutil.read_resource_file_bytes("nonexistent/path/file.txt")
class TestListResourceFiles: class TestListResourceFiles:
@@ -175,9 +175,9 @@ class TestListResourceFiles:
"""Should list files in a resource directory.""" """Should list files in a resource directory."""
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
files = importutil.list_resource_files('templates') files = importutil.list_resource_files("templates")
assert 'config.yaml' in files assert "config.yaml" in files
assert 'default-pipeline-config.json' in files assert "default-pipeline-config.json" in files
assert all(isinstance(file, str) for file in files) assert all(isinstance(file, str) for file in files)
def test_raises_for_nonexistent_directory(self): def test_raises_for_nonexistent_directory(self):
@@ -185,8 +185,8 @@ class TestListResourceFiles:
from langbot.pkg.utils import importutil from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)): with pytest.raises((FileNotFoundError, Exception)):
importutil.list_resource_files('nonexistent_directory_xyz') importutil.list_resource_files("nonexistent_directory_xyz")
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__, '-v']) pytest.main([__file__, "-v"])

View File

@@ -5,7 +5,6 @@ Tests cover:
- Docker environment detection - Docker environment detection
- WebSocket plugin runtime mode - WebSocket plugin runtime mode
""" """
from __future__ import annotations from __future__ import annotations
import os import os
@@ -87,4 +86,4 @@ class TestGetPlatform:
assert platform_module.use_websocket_to_connect_plugin_runtime() is True assert platform_module.use_websocket_to_connect_plugin_runtime() is True
# Restore # Restore
platform_module.standalone_runtime = original platform_module.standalone_runtime = original

View File

@@ -60,12 +60,10 @@ class TestProxyManager:
async def test_initialize_config_overrides_env(self): async def test_initialize_config_overrides_env(self):
"""Config proxy overrides environment variables.""" """Config proxy overrides environment variables."""
mock_app = self._create_mock_app( mock_app = self._create_mock_app(proxy_config={
proxy_config={ 'http': 'http://config-proxy:8080',
'http': 'http://config-proxy:8080', 'https': 'https://config-proxy:8443',
'https': 'https://config-proxy:8443', })
}
)
with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}): with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}):
pm = ProxyManager(mock_app) pm = ProxyManager(mock_app)
@@ -76,12 +74,10 @@ class TestProxyManager:
async def test_initialize_sets_env_variables(self): async def test_initialize_sets_env_variables(self):
"""initialize sets proxy to environment variables.""" """initialize sets proxy to environment variables."""
mock_app = self._create_mock_app( mock_app = self._create_mock_app(proxy_config={
proxy_config={ 'http': 'http://test-proxy:8080',
'http': 'http://test-proxy:8080', 'https': 'https://test-proxy:8443',
'https': 'https://test-proxy:8443', })
}
)
pm = ProxyManager(mock_app) pm = ProxyManager(mock_app)
await pm.initialize() await pm.initialize()
@@ -147,11 +143,9 @@ class TestProxyManager:
async def test_initialize_http_only_config(self): async def test_initialize_http_only_config(self):
"""initialize handles http-only config.""" """initialize handles http-only config."""
mock_app = self._create_mock_app( mock_app = self._create_mock_app(proxy_config={
proxy_config={ 'http': 'http://http-only:8080',
'http': 'http://http-only:8080', })
}
)
# Clear any existing proxy env vars # Clear any existing proxy env vars
env_backup = {} env_backup = {}

View File

@@ -29,63 +29,63 @@ class TestGetRunnerCategory:
def test_empty_url_returns_unknown(self): def test_empty_url_returns_unknown(self):
"""Empty or None URL should return UNKNOWN.""" """Empty or None URL should return UNKNOWN."""
assert get_runner_category('test', '') == RunnerCategory.UNKNOWN assert get_runner_category("test", "") == RunnerCategory.UNKNOWN
assert get_runner_category('test', None) == RunnerCategory.UNKNOWN assert get_runner_category("test", None) == RunnerCategory.UNKNOWN
def test_localhost_returns_local(self): def test_localhost_returns_local(self):
"""localhost URL should be categorized as LOCAL.""" """localhost URL should be categorized as LOCAL."""
assert get_runner_category('test', 'http://localhost:3000') == RunnerCategory.LOCAL assert get_runner_category("test", "http://localhost:3000") == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://localhost') == RunnerCategory.LOCAL assert get_runner_category("test", "https://localhost") == RunnerCategory.LOCAL
def test_127_0_0_1_returns_local(self): def test_127_0_0_1_returns_local(self):
"""127.0.0.1 URL should be categorized as LOCAL.""" """127.0.0.1 URL should be categorized as LOCAL."""
assert get_runner_category('test', 'http://127.0.0.1:8080') == RunnerCategory.LOCAL assert get_runner_category("test", "http://127.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://127.0.0.1') == RunnerCategory.LOCAL assert get_runner_category("test", "https://127.0.0.1") == RunnerCategory.LOCAL
def test_0_0_0_0_returns_local(self): def test_0_0_0_0_returns_local(self):
"""0.0.0.0 URL should be categorized as LOCAL.""" """0.0.0.0 URL should be categorized as LOCAL."""
assert get_runner_category('test', 'http://0.0.0.0:8080') == RunnerCategory.LOCAL assert get_runner_category("test", "http://0.0.0.0:8080") == RunnerCategory.LOCAL
def test_private_ip_192_168_returns_local(self): def test_private_ip_192_168_returns_local(self):
"""192.168.x.x private IP should be categorized as LOCAL.""" """192.168.x.x private IP should be categorized as LOCAL."""
assert get_runner_category('test', 'http://192.168.1.1:3000') == RunnerCategory.LOCAL assert get_runner_category("test", "http://192.168.1.1:3000") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://192.168.0.100') == RunnerCategory.LOCAL assert get_runner_category("test", "http://192.168.0.100") == RunnerCategory.LOCAL
def test_private_ip_10_returns_local(self): def test_private_ip_10_returns_local(self):
"""10.x.x.x private IP should be categorized as LOCAL.""" """10.x.x.x private IP should be categorized as LOCAL."""
assert get_runner_category('test', 'http://10.0.0.1:8080') == RunnerCategory.LOCAL assert get_runner_category("test", "http://10.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://10.255.255.255') == RunnerCategory.LOCAL assert get_runner_category("test", "http://10.255.255.255") == RunnerCategory.LOCAL
def test_private_ip_172_16_31_returns_local(self): def test_private_ip_172_16_31_returns_local(self):
"""172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL.""" """172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL."""
assert get_runner_category('test', 'http://172.16.0.1:8080') == RunnerCategory.LOCAL assert get_runner_category("test", "http://172.16.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.20.0.1') == RunnerCategory.LOCAL assert get_runner_category("test", "http://172.20.0.1") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.31.255.255') == RunnerCategory.LOCAL assert get_runner_category("test", "http://172.31.255.255") == RunnerCategory.LOCAL
def test_n8n_cloud_returns_cloud(self): def test_n8n_cloud_returns_cloud(self):
"""n8n.cloud domain should be categorized as CLOUD.""" """n8n.cloud domain should be categorized as CLOUD."""
assert get_runner_category('test', 'https://myinstance.n8n.cloud') == RunnerCategory.CLOUD assert get_runner_category("test", "https://myinstance.n8n.cloud") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.n8n.io') == RunnerCategory.CLOUD assert get_runner_category("test", "https://test.n8n.io") == RunnerCategory.CLOUD
def test_dify_cloud_returns_cloud(self): def test_dify_cloud_returns_cloud(self):
"""Dify cloud domains should be categorized as CLOUD.""" """Dify cloud domains should be categorized as CLOUD."""
assert get_runner_category('test', 'https://api.dify.ai/v1') == RunnerCategory.CLOUD assert get_runner_category("test", "https://api.dify.ai/v1") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://cloud.dify.ai') == RunnerCategory.CLOUD assert get_runner_category("test", "https://cloud.dify.ai") == RunnerCategory.CLOUD
def test_coze_cloud_returns_cloud(self): def test_coze_cloud_returns_cloud(self):
"""Coze domains should be categorized as CLOUD.""" """Coze domains should be categorized as CLOUD."""
assert get_runner_category('test', 'https://api.coze.com') == RunnerCategory.CLOUD assert get_runner_category("test", "https://api.coze.com") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.coze.cn') == RunnerCategory.CLOUD assert get_runner_category("test", "https://api.coze.cn") == RunnerCategory.CLOUD
def test_langflow_cloud_returns_cloud(self): def test_langflow_cloud_returns_cloud(self):
"""Langflow domains should be categorized as CLOUD.""" """Langflow domains should be categorized as CLOUD."""
assert get_runner_category('test', 'https://cloud.langflow.ai') == RunnerCategory.CLOUD assert get_runner_category("test", "https://cloud.langflow.ai") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.langflow.org') == RunnerCategory.CLOUD assert get_runner_category("test", "https://test.langflow.org") == RunnerCategory.CLOUD
def test_other_url_returns_cloud(self): def test_other_url_returns_cloud(self):
"""Other URLs should default to CLOUD category.""" """Other URLs should default to CLOUD category."""
assert get_runner_category('test', 'https://example.com') == RunnerCategory.CLOUD assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://myserver.example.org') == RunnerCategory.CLOUD assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD
@pytest.mark.parametrize( @pytest.mark.parametrize(
'runner_url', 'runner_url',
@@ -101,7 +101,7 @@ class TestGetRunnerCategory:
) )
def test_invalid_urls_return_unknown(self, runner_url): def test_invalid_urls_return_unknown(self, runner_url):
"""Invalid or incomplete URLs should return UNKNOWN.""" """Invalid or incomplete URLs should return UNKNOWN."""
assert get_runner_category('test', runner_url) == RunnerCategory.UNKNOWN assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
def test_urlparse_exception_returns_unknown(self): def test_urlparse_exception_returns_unknown(self):
"""Exception during URL parsing should return UNKNOWN.""" """Exception during URL parsing should return UNKNOWN."""
@@ -109,15 +109,15 @@ class TestGetRunnerCategory:
from langbot.pkg.utils import runner from langbot.pkg.utils import runner
def mock_urlparse(url): def mock_urlparse(url):
raise Exception('URL parsing failed') raise Exception("URL parsing failed")
with patch('langbot.pkg.utils.runner.urlparse', side_effect=mock_urlparse): with patch("langbot.pkg.utils.runner.urlparse", side_effect=mock_urlparse):
result = runner.get_runner_category('test', 'http://example.com') result = runner.get_runner_category("test", "http://example.com")
assert result == RunnerCategory.UNKNOWN assert result == RunnerCategory.UNKNOWN
def test_url_without_scheme_returns_unknown(self): def test_url_without_scheme_returns_unknown(self):
"""URL without scheme should return UNKNOWN.""" """URL without scheme should return UNKNOWN."""
assert get_runner_category('test', 'example.com') == RunnerCategory.UNKNOWN assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN
@pytest.mark.parametrize( @pytest.mark.parametrize(
'runner_url', 'runner_url',
@@ -146,21 +146,20 @@ class TestGetRunnerCategory:
"""Domain names that only look like private IP prefixes should not be LOCAL.""" """Domain names that only look like private IP prefixes should not be LOCAL."""
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
class TestIsCloudRunner: class TestIsCloudRunner:
"""Test is_cloud_runner helper function.""" """Test is_cloud_runner helper function."""
def test_cloud_runner_returns_true(self): def test_cloud_runner_returns_true(self):
"""Cloud URL should return True.""" """Cloud URL should return True."""
assert is_cloud_runner('test', 'https://api.dify.ai') is True assert is_cloud_runner("test", "https://api.dify.ai") is True
def test_local_runner_returns_false(self): def test_local_runner_returns_false(self):
"""Local URL should return False.""" """Local URL should return False."""
assert is_cloud_runner('test', 'http://localhost:3000') is False assert is_cloud_runner("test", "http://localhost:3000") is False
def test_unknown_returns_false(self): def test_unknown_returns_false(self):
"""Unknown category should return False.""" """Unknown category should return False."""
assert is_cloud_runner('test', None) is False assert is_cloud_runner("test", None) is False
class TestIsLocalRunner: class TestIsLocalRunner:
@@ -168,15 +167,15 @@ class TestIsLocalRunner:
def test_local_runner_returns_true(self): def test_local_runner_returns_true(self):
"""Local URL should return True.""" """Local URL should return True."""
assert is_local_runner('test', 'http://localhost:3000') is True assert is_local_runner("test", "http://localhost:3000") is True
def test_cloud_runner_returns_false(self): def test_cloud_runner_returns_false(self):
"""Cloud URL should return False.""" """Cloud URL should return False."""
assert is_local_runner('test', 'https://api.dify.ai') is False assert is_local_runner("test", "https://api.dify.ai") is False
def test_unknown_returns_false(self): def test_unknown_returns_false(self):
"""Unknown category should return False.""" """Unknown category should return False."""
assert is_local_runner('test', None) is False assert is_local_runner("test", None) is False
class TestGetRunnerInfo: class TestGetRunnerInfo:
@@ -184,17 +183,17 @@ class TestGetRunnerInfo:
def test_returns_dict_with_expected_keys(self): def test_returns_dict_with_expected_keys(self):
"""Should return dict with name, url, and category keys.""" """Should return dict with name, url, and category keys."""
info = get_runner_info('my-runner', 'http://localhost:3000') info = get_runner_info("my-runner", "http://localhost:3000")
assert 'name' in info assert "name" in info
assert 'url' in info assert "url" in info
assert 'category' in info assert "category" in info
def test_includes_correct_values(self): def test_includes_correct_values(self):
"""Should include correct values in dict.""" """Should include correct values in dict."""
info = get_runner_info('my-runner', 'http://localhost:3000') info = get_runner_info("my-runner", "http://localhost:3000")
assert info['name'] == 'my-runner' assert info["name"] == "my-runner"
assert info['url'] == 'http://localhost:3000' assert info["url"] == "http://localhost:3000"
assert info['category'] == RunnerCategory.LOCAL assert info["category"] == RunnerCategory.LOCAL
class TestExtractRunnerUrl: class TestExtractRunnerUrl:
@@ -204,58 +203,74 @@ class TestExtractRunnerUrl:
"""Should extract base-url from dify-service-api config.""" """Should extract base-url from dify-service-api config."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}} pipeline_config = {
url = extract_runner_url('dify-service-api', runner, pipeline_config) "ai": {
assert url == 'https://api.dify.ai' "dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
url = extract_runner_url("dify-service-api", runner, pipeline_config)
assert url == "https://api.dify.ai"
def test_n8n_service_api_extracts_url(self): def test_n8n_service_api_extracts_url(self):
"""Should extract webhook-url from n8n-service-api config.""" """Should extract webhook-url from n8n-service-api config."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {'ai': {'n8n-service-api': {'webhook-url': 'https://my.n8n.cloud/webhook'}}} pipeline_config = {
url = extract_runner_url('n8n-service-api', runner, pipeline_config) "ai": {
assert url == 'https://my.n8n.cloud/webhook' "n8n-service-api": {"webhook-url": "https://my.n8n.cloud/webhook"}
}
}
url = extract_runner_url("n8n-service-api", runner, pipeline_config)
assert url == "https://my.n8n.cloud/webhook"
def test_coze_api_extracts_url(self): def test_coze_api_extracts_url(self):
"""Should extract api-base from coze-api config.""" """Should extract api-base from coze-api config."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {'ai': {'coze-api': {'api-base': 'https://api.coze.com'}}} pipeline_config = {
url = extract_runner_url('coze-api', runner, pipeline_config) "ai": {
assert url == 'https://api.coze.com' "coze-api": {"api-base": "https://api.coze.com"}
}
}
url = extract_runner_url("coze-api", runner, pipeline_config)
assert url == "https://api.coze.com"
def test_langflow_api_extracts_url(self): def test_langflow_api_extracts_url(self):
"""Should extract base-url from langflow-api config.""" """Should extract base-url from langflow-api config."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {'ai': {'langflow-api': {'base-url': 'https://cloud.langflow.ai'}}} pipeline_config = {
url = extract_runner_url('langflow-api', runner, pipeline_config) "ai": {
assert url == 'https://cloud.langflow.ai' "langflow-api": {"base-url": "https://cloud.langflow.ai"}
}
}
url = extract_runner_url("langflow-api", runner, pipeline_config)
assert url == "https://cloud.langflow.ai"
def test_unknown_runner_returns_none(self): def test_unknown_runner_returns_none(self):
"""Unknown runner name should return None.""" """Unknown runner name should return None."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {} pipeline_config = {}
url = extract_runner_url('unknown-runner', runner, pipeline_config) url = extract_runner_url("unknown-runner", runner, pipeline_config)
assert url is None assert url is None
def test_none_runner_returns_none(self): def test_none_runner_returns_none(self):
"""None runner should return None.""" """None runner should return None."""
url = extract_runner_url('test', None, {}) url = extract_runner_url("test", None, {})
assert url is None assert url is None
def test_runner_without_pipeline_config_returns_none(self): def test_runner_without_pipeline_config_returns_none(self):
"""Runner without pipeline_config attribute should return None.""" """Runner without pipeline_config attribute should return None."""
runner = Mock(spec=[]) # Empty spec means no attributes runner = Mock(spec=[]) # Empty spec means no attributes
url = extract_runner_url('test', runner, {}) url = extract_runner_url("test", runner, {})
assert url is None assert url is None
def test_none_pipeline_config_returns_none(self): def test_none_pipeline_config_returns_none(self):
"""None pipeline_config should return None.""" """None pipeline_config should return None."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
url = extract_runner_url('dify-service-api', runner, None) url = extract_runner_url("dify-service-api", runner, None)
assert url is None assert url is None
def test_missing_ai_config_returns_none(self): def test_missing_ai_config_returns_none(self):
@@ -263,7 +278,7 @@ class TestExtractRunnerUrl:
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {} pipeline_config = {}
url = extract_runner_url('dify-service-api', runner, pipeline_config) url = extract_runner_url("dify-service-api", runner, pipeline_config)
assert url is None assert url is None
@@ -274,15 +289,19 @@ class TestGetRunnerCategoryFromRunner:
"""Should extract URL and return correct category.""" """Should extract URL and return correct category."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}} pipeline_config = {
category = get_runner_category_from_runner('dify-service-api', runner, pipeline_config) "ai": {
"dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
category = get_runner_category_from_runner("dify-service-api", runner, pipeline_config)
assert category == RunnerCategory.CLOUD assert category == RunnerCategory.CLOUD
def test_returns_unknown_for_missing_url(self): def test_returns_unknown_for_missing_url(self):
"""Should return UNKNOWN when URL cannot be extracted.""" """Should return UNKNOWN when URL cannot be extracted."""
runner = Mock() runner = Mock()
runner.pipeline_config = {} runner.pipeline_config = {}
category = get_runner_category_from_runner('unknown', runner, {}) category = get_runner_category_from_runner("unknown", runner, {})
assert category == RunnerCategory.UNKNOWN assert category == RunnerCategory.UNKNOWN
@@ -291,9 +310,9 @@ class TestConstants:
def test_runner_category_constants(self): def test_runner_category_constants(self):
"""RunnerCategory should have LOCAL, CLOUD, UNKNOWN.""" """RunnerCategory should have LOCAL, CLOUD, UNKNOWN."""
assert RunnerCategory.LOCAL == 'local' assert RunnerCategory.LOCAL == "local"
assert RunnerCategory.CLOUD == 'cloud' assert RunnerCategory.CLOUD == "cloud"
assert RunnerCategory.UNKNOWN == 'unknown' assert RunnerCategory.UNKNOWN == "unknown"
def test_cloud_domains_not_empty(self): def test_cloud_domains_not_empty(self):
"""CLOUD_DOMAINS should not be empty.""" """CLOUD_DOMAINS should not be empty."""
@@ -304,5 +323,5 @@ class TestConstants:
assert len(LOCAL_PATTERNS) > 0 assert len(LOCAL_PATTERNS) > 0
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__, '-v']) pytest.main([__file__, "-v"])

View File

@@ -68,7 +68,11 @@ class TestNormalizeFilter:
def test_normalize_filter_multiple_conditions(self): def test_normalize_filter_multiple_conditions(self):
"""Multiple top-level keys are AND-ed (returned as multiple triples).""" """Multiple top-level keys are AND-ed (returned as multiple triples)."""
result = normalize_filter({'file_id': 'abc', 'status': {'$ne': 'deleted'}, 'created_at': {'$gte': 1700000000}}) result = normalize_filter({
'file_id': 'abc',
'status': {'$ne': 'deleted'},
'created_at': {'$gte': 1700000000}
})
assert len(result) == 3 assert len(result) == 3
# Order should match dict iteration order # Order should match dict iteration order
@@ -145,7 +149,11 @@ class TestStripUnsupportedFields:
('file_id', '$eq', 'def'), ('file_id', '$eq', 'def'),
] ]
result = strip_unsupported_fields(triples, {'file_id', 'chunk_uuid'}, field_aliases={'uuid': 'chunk_uuid'}) result = strip_unsupported_fields(
triples,
{'file_id', 'chunk_uuid'},
field_aliases={'uuid': 'chunk_uuid'}
)
assert len(result) == 2 assert len(result) == 2
# 'uuid' should be resolved to 'chunk_uuid' # 'uuid' should be resolved to 'chunk_uuid'
@@ -161,7 +169,7 @@ class TestStripUnsupportedFields:
result = strip_unsupported_fields( result = strip_unsupported_fields(
triples, triples,
{'file_id'}, # chunk_uuid not supported {'file_id'}, # chunk_uuid not supported
field_aliases={'uuid': 'chunk_uuid'}, field_aliases={'uuid': 'chunk_uuid'}
) )
assert result == [] assert result == []
@@ -199,5 +207,4 @@ class TestSupportedOpsConstant:
def test_supported_ops_is_frozenset(self): def test_supported_ops_is_frozenset(self):
"""SUPPORTED_OPS is a frozenset for immutability.""" """SUPPORTED_OPS is a frozenset for immutability."""
from collections.abc import Set from collections.abc import Set
assert isinstance(SUPPORTED_OPS, Set)
assert isinstance(SUPPORTED_OPS, Set)

View File

@@ -55,7 +55,6 @@ class TestVectorDBManagerInitialization:
# Run initialize synchronously for test # Run initialize synchronously for test
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
# Chroma should be instantiated # Chroma should be instantiated
@@ -77,7 +76,6 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_chroma_class.assert_called_once_with(mock_app) mock_chroma_class.assert_called_once_with(mock_app)
@@ -98,7 +96,6 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_qdrant_class.assert_called_once_with(mock_app) mock_qdrant_class.assert_called_once_with(mock_app)
@@ -118,7 +115,6 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_seekdb_class.assert_called_once_with(mock_app) mock_seekdb_class.assert_called_once_with(mock_app)
@@ -127,7 +123,11 @@ class TestVectorDBManagerInitialization:
"""Milvus config with custom URI.""" """Milvus config with custom URI."""
vdb_config = { vdb_config = {
'use': 'milvus', 'use': 'milvus',
'milvus': {'uri': 'http://localhost:19530', 'token': 'root:Milvus', 'db_name': 'langbot_db'}, 'milvus': {
'uri': 'http://localhost:19530',
'token': 'root:Milvus',
'db_name': 'langbot_db'
}
} }
mock_app = self._create_mock_app(vdb_config) mock_app = self._create_mock_app(vdb_config)
@@ -141,11 +141,13 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_milvus_class.assert_called_once_with( mock_milvus_class.assert_called_once_with(
mock_app, uri='http://localhost:19530', token='root:Milvus', db_name='langbot_db' mock_app,
uri='http://localhost:19530',
token='root:Milvus',
db_name='langbot_db'
) )
def test_initialize_milvus_backend_defaults(self): def test_initialize_milvus_backend_defaults(self):
@@ -163,45 +165,23 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
# Should use default values # Should use default values
mock_milvus_class.assert_called_once_with(mock_app, uri='./data/milvus.db', token=None, db_name='default') mock_milvus_class.assert_called_once_with(
mock_app,
uri='./data/milvus.db',
token=None,
db_name='default'
)
def test_initialize_pgvector_with_connection_string(self): def test_initialize_pgvector_with_connection_string(self):
"""pgvector with connection string.""" """pgvector with connection string."""
vdb_config = {'use': 'pgvector', 'pgvector': {'connection_string': 'postgresql://user:pass@host:5432/langbot'}}
mock_app = self._create_mock_app(vdb_config)
mocks = self._make_vector_import_mocks()
mock_pgvector_class = MagicMock()
mocks['langbot.pkg.vector.vdbs.pgvector_db'].PgVectorDatabase = mock_pgvector_class
with isolated_sys_modules(mocks):
from langbot.pkg.vector.mgr import VectorDBManager
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with(
mock_app, connection_string='postgresql://user:pass@host:5432/langbot'
)
def test_initialize_pgvector_with_individual_params(self):
"""pgvector with individual connection parameters."""
vdb_config = { vdb_config = {
'use': 'pgvector', 'use': 'pgvector',
'pgvector': { 'pgvector': {
'host': 'db.example.com', 'connection_string': 'postgresql://user:pass@host:5432/langbot'
'port': 5433, }
'database': 'vectordb',
'user': 'admin',
'password': 'secret',
},
} }
mock_app = self._create_mock_app(vdb_config) mock_app = self._create_mock_app(vdb_config)
@@ -215,11 +195,46 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with( mock_pgvector_class.assert_called_once_with(
mock_app, host='db.example.com', port=5433, database='vectordb', user='admin', password='secret' mock_app,
connection_string='postgresql://user:pass@host:5432/langbot'
)
def test_initialize_pgvector_with_individual_params(self):
"""pgvector with individual connection parameters."""
vdb_config = {
'use': 'pgvector',
'pgvector': {
'host': 'db.example.com',
'port': 5433,
'database': 'vectordb',
'user': 'admin',
'password': 'secret'
}
}
mock_app = self._create_mock_app(vdb_config)
mocks = self._make_vector_import_mocks()
mock_pgvector_class = MagicMock()
mocks['langbot.pkg.vector.vdbs.pgvector_db'].PgVectorDatabase = mock_pgvector_class
with isolated_sys_modules(mocks):
from langbot.pkg.vector.mgr import VectorDBManager
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with(
mock_app,
host='db.example.com',
port=5433,
database='vectordb',
user='admin',
password='secret'
) )
def test_initialize_pgvector_defaults(self): def test_initialize_pgvector_defaults(self):
@@ -237,11 +252,15 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with( mock_pgvector_class.assert_called_once_with(
mock_app, host='localhost', port=5432, database='langbot', user='postgres', password='postgres' mock_app,
host='localhost',
port=5432,
database='langbot',
user='postgres',
password='postgres'
) )
def test_initialize_unknown_backend_defaults_to_chroma(self): def test_initialize_unknown_backend_defaults_to_chroma(self):
@@ -259,7 +278,6 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app) mgr = VectorDBManager(mock_app)
import asyncio import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize()) asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_chroma_class.assert_called_once_with(mock_app) mock_chroma_class.assert_called_once_with(mock_app)
@@ -317,4 +335,4 @@ class TestVectorDBManagerProxies:
mgr.vector_db = mock_vector_db mgr.vector_db = mock_vector_db
result = mgr.get_supported_search_types() result = mgr.get_supported_search_types()
assert result == ['vector', 'full_text'] assert result == ['vector', 'full_text']

View File

@@ -39,7 +39,6 @@ class TestVectorDatabaseAbstractMethods:
def test_abstract_methods_required(self): def test_abstract_methods_required(self):
"""Subclass must implement all abstract methods.""" """Subclass must implement all abstract methods."""
class IncompleteVectorDB(VectorDatabase): class IncompleteVectorDB(VectorDatabase):
pass pass
@@ -48,21 +47,11 @@ class TestVectorDatabaseAbstractMethods:
def test_supported_search_types_default(self): def test_supported_search_types_default(self):
"""Default supported_search_types returns [VECTOR].""" """Default supported_search_types returns [VECTOR]."""
class MinimalVectorDB(VectorDatabase): class MinimalVectorDB(VectorDatabase):
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass pass
async def search( async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass pass
async def delete_by_file_id(self, collection, file_id): async def delete_by_file_id(self, collection, file_id):
@@ -82,21 +71,11 @@ class TestVectorDatabaseAbstractMethods:
def test_list_by_filter_default_implementation(self): def test_list_by_filter_default_implementation(self):
"""list_by_filter has default implementation returning empty.""" """list_by_filter has default implementation returning empty."""
class MinimalVectorDB(VectorDatabase): class MinimalVectorDB(VectorDatabase):
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass pass
async def search( async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass pass
async def delete_by_file_id(self, collection, file_id): async def delete_by_file_id(self, collection, file_id):
@@ -114,8 +93,9 @@ class TestVectorDatabaseAbstractMethods:
db = MinimalVectorDB() db = MinimalVectorDB()
# list_by_filter should return empty list and -1 for total # list_by_filter should return empty list and -1 for total
import asyncio import asyncio
result = asyncio.get_event_loop().run_until_complete(
result = asyncio.get_event_loop().run_until_complete(db.list_by_filter('test_collection')) db.list_by_filter('test_collection')
)
assert result == ([], -1) assert result == ([], -1)
@@ -125,17 +105,14 @@ class TestVectorDatabaseInterface:
@pytest.fixture @pytest.fixture
def mock_vector_db(self): def mock_vector_db(self):
"""Create a minimal mock VectorDatabase for testing.""" """Create a minimal mock VectorDatabase for testing."""
class MockVectorDB(VectorDatabase): class MockVectorDB(VectorDatabase):
def __init__(self): def __init__(self):
self.add_embeddings = AsyncMock() self.add_embeddings = AsyncMock()
self.search = AsyncMock( self.search = AsyncMock(return_value={
return_value={ 'ids': [['id1', 'id2']],
'ids': [['id1', 'id2']], 'distances': [[0.1, 0.2]],
'distances': [[0.1, 0.2]], 'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]]
'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]], })
}
)
self.delete_by_file_id = AsyncMock() self.delete_by_file_id = AsyncMock()
self.delete_by_filter = AsyncMock(return_value=5) self.delete_by_filter = AsyncMock(return_value=5)
self.get_or_create_collection = AsyncMock() self.get_or_create_collection = AsyncMock()
@@ -144,16 +121,7 @@ class TestVectorDatabaseInterface:
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass pass
async def search( async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass pass
async def delete_by_file_id(self, collection, file_id): async def delete_by_file_id(self, collection, file_id):
@@ -178,7 +146,7 @@ class TestVectorDatabaseInterface:
ids=['id1', 'id2'], ids=['id1', 'id2'],
embeddings_list=[[0.1, 0.2], [0.3, 0.4]], embeddings_list=[[0.1, 0.2], [0.3, 0.4]],
metadatas=[{'a': 1}, {'b': 2}], metadatas=[{'a': 1}, {'b': 2}],
documents=['doc1', 'doc2'], documents=['doc1', 'doc2']
) )
mock_vector_db.add_embeddings.assert_called_once() mock_vector_db.add_embeddings.assert_called_once()
@@ -194,7 +162,7 @@ class TestVectorDatabaseInterface:
search_type='hybrid', search_type='hybrid',
query_text='search text', query_text='search text',
filter={'file_id': 'abc'}, filter={'file_id': 'abc'},
vector_weight=0.7, vector_weight=0.7
) )
mock_vector_db.search.assert_called_once() mock_vector_db.search.assert_called_once()
@@ -202,4 +170,4 @@ class TestVectorDatabaseInterface:
async def test_delete_by_filter_returns_int(self, mock_vector_db): async def test_delete_by_filter_returns_int(self, mock_vector_db):
"""delete_by_filter returns int count.""" """delete_by_filter returns int count."""
result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'}) result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'})
assert isinstance(result, int) assert isinstance(result, int)

View File

@@ -5,7 +5,6 @@ Tests cover:
- _build_milvus_expr: Milvus boolean expression string conversion - _build_milvus_expr: Milvus boolean expression string conversion
- _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion - _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion
""" """
from __future__ import annotations from __future__ import annotations
from importlib import import_module from importlib import import_module
@@ -123,13 +122,11 @@ class TestQdrantFilterConversion:
"""Multiple conditions are combined in must/must_not.""" """Multiple conditions are combined in must/must_not."""
qdrant_module = get_qdrant_module() qdrant_module = get_qdrant_module()
result = qdrant_module._build_qdrant_filter( result = qdrant_module._build_qdrant_filter({
{ 'file_id': 'abc',
'file_id': 'abc', 'status': {'$ne': 'deleted'},
'status': {'$ne': 'deleted'}, 'created_at': {'$gte': 100},
'created_at': {'$gte': 100}, })
}
)
assert len(result.must) == 2 # file_id eq + created_at gte assert len(result.must) == 2 # file_id eq + created_at gte
assert len(result.must_not) == 1 # status ne assert len(result.must_not) == 1 # status ne
@@ -201,12 +198,10 @@ class TestMilvusFilterConversion:
"""Multiple conditions are joined with 'and'.""" """Multiple conditions are joined with 'and'."""
milvus_module = get_milvus_module() milvus_module = get_milvus_module()
result = milvus_module._build_milvus_expr( result = milvus_module._build_milvus_expr({
{ 'file_id': 'abc',
'file_id': 'abc', 'chunk_uuid': {'$ne': 'def'},
'chunk_uuid': {'$ne': 'def'}, })
}
)
assert 'and' in result assert 'and' in result
assert 'file_id == "abc"' in result assert 'file_id == "abc"' in result
assert 'chunk_uuid != "def"' in result assert 'chunk_uuid != "def"' in result
@@ -277,7 +272,6 @@ class TestPgVectorFilterConversion:
assert len(result) == 1 assert len(result) == 1
# Verify it's a SQLAlchemy BinaryExpression # Verify it's a SQLAlchemy BinaryExpression
from sqlalchemy.sql.expression import BinaryExpression from sqlalchemy.sql.expression import BinaryExpression
assert isinstance(result[0], BinaryExpression) assert isinstance(result[0], BinaryExpression)
def test_ne_operator_creates_inequality_condition(self): def test_ne_operator_creates_inequality_condition(self):
@@ -327,12 +321,10 @@ class TestPgVectorFilterConversion:
"""Multiple conditions return list of conditions.""" """Multiple conditions return list of conditions."""
pgvector_module = get_pgvector_module() pgvector_module = get_pgvector_module()
result = pgvector_module._build_pg_conditions( result = pgvector_module._build_pg_conditions({
{ 'file_id': 'abc',
'file_id': 'abc', 'chunk_uuid': {'$ne': 'def'},
'chunk_uuid': {'$ne': 'def'}, })
}
)
assert len(result) == 2 assert len(result) == 2
@@ -357,13 +349,11 @@ class TestPgVectorFilterConversion:
"""Only supported fields (text, file_id, chunk_uuid) are kept.""" """Only supported fields (text, file_id, chunk_uuid) are kept."""
pgvector_module = get_pgvector_module() pgvector_module = get_pgvector_module()
result = pgvector_module._build_pg_conditions( result = pgvector_module._build_pg_conditions({
{ 'text': {'$ne': ''},
'text': {'$ne': ''}, 'file_id': 'abc',
'file_id': 'abc', 'chunk_uuid': {'$in': ['x', 'y']},
'chunk_uuid': {'$in': ['x', 'y']}, 'unsupported': 'value',
'unsupported': 'value', })
}
)
assert len(result) == 3 # Only supported fields assert len(result) == 3 # Only supported fields

View File

@@ -1,3 +1,3 @@
""" """
Test utilities package. Test utilities package.
""" """

View File

@@ -26,7 +26,6 @@ from unittest.mock import MagicMock
class MockLifecycleControlScope(enum.Enum): class MockLifecycleControlScope(enum.Enum):
"""Mock enum for breaking circular import in core.entities.""" """Mock enum for breaking circular import in core.entities."""
APPLICATION = 'application' APPLICATION = 'application'
PLATFORM = 'platform' PLATFORM = 'platform'
PLUGIN = 'plugin' PLUGIN = 'plugin'
@@ -191,4 +190,4 @@ def get_handler_modules_to_clear(handler_name: str) -> list[str]:
'langbot.pkg.pipeline.process.handler', 'langbot.pkg.pipeline.process.handler',
'langbot.pkg.pipeline.process.handlers', 'langbot.pkg.pipeline.process.handlers',
f'langbot.pkg.pipeline.process.handlers.{handler_name}', f'langbot.pkg.pipeline.process.handlers.{handler_name}',
] ]

2
web/.gitignore vendored
View File

@@ -12,8 +12,6 @@
# testing # testing
/coverage /coverage
/playwright-report
/test-results
# next.js # next.js
/dist/ /dist/

View File

@@ -1,13 +1,3 @@
# Debug LangBot Frontend # Debug LangBot Frontend
Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information. Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information.
## Tests
Run the frontend smoke tests without a backend process:
```bash
pnpm test:e2e
```
The Playwright suite starts Vite and mocks the LangBot backend and Space APIs.

View File

@@ -6,7 +6,6 @@
"dev": "vite", "dev": "vite",
"build": "tsc && vite build", "build": "tsc && vite build",
"preview": "vite preview", "preview": "vite preview",
"test:e2e": "playwright test",
"lint": "eslint .", "lint": "eslint .",
"format": "prettier --write ." "format": "prettier --write ."
}, },
@@ -87,7 +86,6 @@
"zod": "^3.24.4" "zod": "^3.24.4"
}, },
"devDependencies": { "devDependencies": {
"@playwright/test": "^1.61.0",
"@types/debug": "^4.1.12", "@types/debug": "^4.1.12",
"@types/estree": "^1.0.8", "@types/estree": "^1.0.8",
"@types/estree-jsx": "^1.0.5", "@types/estree-jsx": "^1.0.5",

Some files were not shown because too many files have changed in this diff Show More