Compare commits

...

2 Commits

Author SHA1 Message Date
huanghuoguoguo
ff0c5a6f0a test: format test suite 2026-06-16 11:13:05 +08:00
huanghuoguoguo
1ae5aacc00 test: add frontend smoke and backend e2e CI (#2251) 2026-06-16 11:09:55 +08:00
105 changed files with 2492 additions and 1723 deletions

46
.github/workflows/frontend-tests.yml vendored Normal file
View File

@@ -0,0 +1,46 @@
name: Frontend Tests
on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
push:
branches:
- master
- develop
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
jobs:
playwright-smoke:
name: Playwright Smoke
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '25'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 8.9.2
- name: Install dependencies
working-directory: web
run: pnpm install --frozen-lockfile
- name: Install Playwright browsers
working-directory: web
run: pnpm exec playwright install --with-deps chromium
- name: Run Playwright smoke tests
working-directory: web
run: pnpm test:e2e

View File

@@ -29,7 +29,7 @@ jobs:
run: uv sync --dev
- name: Run ruff check
run: uv run ruff check src
run: uv run ruff check src/langbot/ tests/ --output-format=concise
- name: Run ruff format
run: uv run ruff format src --check

View File

@@ -84,6 +84,67 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
e2e:
name: E2E Startup Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run E2E startup tests
run: uv run pytest tests/e2e -q --tb=short
- name: E2E Test Summary
if: always()
run: |
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
box-integration:
name: Box Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Check Docker runtime
run: docker info
- name: Run Box integration tests
run: uv run pytest tests/integration_tests -q --tb=short
- name: Box Integration Test Summary
if: always()
run: |
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage:
name: Coverage Gate
runs-on: ubuntu-latest
@@ -129,4 +190,4 @@ jobs:
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

View File

@@ -1,6 +1,7 @@
# LangBot Test Suite
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
This directory contains the LangBot backend test suite, including unit tests,
integration tests, startup E2E tests, and container-backed Box runtime tests.
## Quality Gate Layers
@@ -10,10 +11,15 @@ LangBot uses a layered quality gate system for developers and CI:
|-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
**Note**: PostgreSQL migration tests and slow tests are NOT in local default
gates. They run in separate CI workflows. Frontend Playwright tests live under
`web/tests/e2e` and are documented in `web/README.md`.
### Developer Workflow
@@ -28,6 +34,9 @@ make test-all-local
bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min
uv run --python 3.12 pytest tests/e2e -q --tb=short
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
cd web && pnpm test:e2e
```
### Coverage Baseline
@@ -70,6 +79,12 @@ tests/
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── e2e/ # Real LangBot startup E2E tests
│ ├── conftest.py
│ ├── test_startup.py
│ └── utils/
├── integration_tests/ # Container-backed integration tests
│ └── box/ # Box runtime and MCP process tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
@@ -303,6 +318,44 @@ These tests:
- Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys
### Running backend E2E startup tests
Backend E2E tests start a real LangBot process with a generated minimal
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
They do not require provider keys or external services.
```bash
uv run --python 3.12 pytest tests/e2e -q --tb=short
```
These tests verify startup orchestration, migrations, API route registration,
and the minimal no-LLM startup path. The E2E process manager disables ambient
proxy variables for subprocess startup and uses direct localhost HTTP clients,
so local proxy settings should not affect the health checks.
### Running Box integration tests
Box integration tests exercise the real sandbox runtime path, including command
execution, session persistence, managed process WebSocket attachment, and
cleanup behavior.
```bash
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
```
These tests require a working Docker or Podman runtime. In CI, the dedicated
Box integration job checks Docker availability before running the tests.
### Running frontend E2E tests
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
and mock the LangBot backend and Space APIs, so no backend process is required.
```bash
cd web
pnpm test:e2e
```
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
@@ -320,6 +373,9 @@ Tests are automatically run on:
- Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
they exercise process/container behavior instead of pure Python compatibility.
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
## Adding New Tests
@@ -406,4 +462,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
- [ ] Add E2E tests
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis
- [ ] Add property-based testing with Hypothesis

View File

@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0) as client:
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client:
yield client
@pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db'
return e2e_tmpdir / 'data' / 'langbot.db'

View File

@@ -38,12 +38,13 @@ class TestStartupFlow:
# System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, e2e_db_path):
def test_database_initialized(self, langbot_process, e2e_db_path):
"""Verify SQLite database was created and initialized."""
assert e2e_db_path.exists()
# Database should have some tables after migration
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
@@ -74,10 +75,13 @@ class TestStartupFlow:
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint."""
# First startup may allow initial setup
response = e2e_client.post('/api/v1/user/auth', json={
'username': 'admin',
'password': 'admin',
})
response = e2e_client.post(
'/api/v1/user/auth',
json={
'user': 'admin',
'password': 'admin',
},
)
# Response could be:
# - 200 if auth succeeds
@@ -94,9 +98,10 @@ class TestStartupStages:
# If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, e2e_db_path):
def test_migrations_applied(self, langbot_process, e2e_db_path):
"""Verify database migrations were applied."""
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()

View File

@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
for path in directories.values():
path.mkdir(parents=True, exist_ok=True)
return directories
return directories

View File

@@ -44,6 +44,17 @@ class LangBotProcess:
# Prepare environment
env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src')
for proxy_key in (
'HTTP_PROXY',
'HTTPS_PROXY',
'ALL_PROXY',
'http_proxy',
'https_proxy',
'all_proxy',
):
env.pop(proxy_key, None)
env['NO_PROXY'] = '127.0.0.1,localhost'
env['no_proxy'] = '127.0.0.1,localhost'
# Set API port via environment variable
env['API__PORT'] = str(self.port)
@@ -79,9 +90,11 @@ precision = 2
f.write(coveragerc_content)
cmd = [
'coverage', 'run',
'coverage',
'run',
'--rcfile=' + str(coveragerc_path),
'-m', 'langbot',
'-m',
'langbot',
]
else:
cmd = ['uv', 'run', 'python', '-m', 'langbot']
@@ -113,6 +126,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0,
follow_redirects=False,
trust_env=False,
)
if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}')
@@ -185,6 +200,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0,
follow_redirects=False,
trust_env=False,
)
return r.status_code == 200
except Exception:
@@ -201,4 +218,4 @@ def find_project_root() -> Path:
return parent
# Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build')
return Path('/home/glwuy/langbot-app/LangBot-test-build')

View File

@@ -58,45 +58,45 @@ from tests.factories.platform import (
__all__ = [
# App
"FakeApp",
"fake_app",
'FakeApp',
'fake_app',
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
'text_chain',
'group_text_chain',
'mention_chain',
'image_chain',
# Message events
"friend_message_event",
"group_message_event",
'friend_message_event',
'group_message_event',
# Mock adapters
"mock_adapter",
'mock_adapter',
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
"image_query",
"file_query",
"unsupported_query",
"voice_query",
"at_all_query",
"query_with_session",
"query_with_config",
'text_query',
'group_text_query',
'private_text_query',
'command_query',
'mention_query',
'empty_query',
'image_query',
'file_query',
'unsupported_query',
'voice_query',
'at_all_query',
'query_with_session',
'query_with_config',
# Provider
"FakeProvider",
"fake_provider",
"fake_provider_pong",
"fake_provider_timeout",
"fake_provider_auth_error",
"fake_provider_rate_limit",
"fake_provider_malformed",
"fake_model",
'FakeProvider',
'fake_provider',
'fake_provider_pong',
'fake_provider_timeout',
'fake_provider_auth_error',
'fake_provider_rate_limit',
'fake_provider_malformed',
'fake_model',
# Platform
"FakePlatform",
"fake_platform",
"fake_platform_with_streaming",
"fake_platform_with_failure",
"mock_platform_adapter",
]
'FakePlatform',
'fake_platform',
'fake_platform_with_streaming',
'fake_platform_with_failure',
'mock_platform_adapter',
]

View File

@@ -30,32 +30,36 @@ def _next_query_id() -> int:
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
def text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
return platform_message.MessageChain(
[
platform_message.At(target=target),
platform_message.Plain(text=f' {text}'),
]
)
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
@@ -66,13 +70,15 @@ def image_chain(
def command_chain(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=f'{prefix}{command}'),
]
)
# ============== Message Event Factories ==============
@@ -81,7 +87,7 @@ def command_chain(
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
@@ -90,7 +96,7 @@ def friend_message_event(
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -100,9 +106,9 @@ def friend_message_event(
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
@@ -117,7 +123,7 @@ def group_message_event(
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -152,36 +158,36 @@ def _base_query(
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'query_id': query_id,
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'message_chain': message_chain,
'message_event': message_event,
'adapter': adapter,
'pipeline_uuid': 'test-pipeline-uuid',
'bot_uuid': 'test-bot-uuid',
'pipeline_config': {
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
'session': None,
'prompt': None,
'messages': [],
'user_message': None,
'use_funcs': [],
'use_llm_model_uuid': None,
'variables': {},
'resp_messages': [],
'resp_message_chain': None,
'current_stage_name': None,
}
# Apply overrides
@@ -192,7 +198,7 @@ def _base_query(
def text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -212,7 +218,7 @@ def text_query(
def private_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -221,7 +227,7 @@ def private_text_query(
def group_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -242,8 +248,8 @@ def group_text_query(
def command_query(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -263,7 +269,7 @@ def command_query(
def mention_query(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
@@ -301,8 +307,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -322,9 +328,9 @@ def image_query(
def file_query(
url: str = "https://example.com/document.pdf",
name: str = "document.pdf",
text: str = "",
url: str = 'https://example.com/document.pdf',
name: str = 'document.pdf',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -348,8 +354,8 @@ def file_query(
def unsupported_query(
unsupported_type: str = "CustomComponent",
text: str = "",
unsupported_type: str = 'CustomComponent',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -358,7 +364,7 @@ def unsupported_query(
if text:
components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}'))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
@@ -374,7 +380,7 @@ def unsupported_query(
def query_with_session(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None,
**overrides,
@@ -389,7 +395,7 @@ def query_with_session(
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -398,7 +404,7 @@ def query_with_session(
def query_with_config(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None,
**overrides,
@@ -410,22 +416,22 @@ def query_with_config(
"""
if pipeline_config is None:
pipeline_config = {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
}
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query(
url: str = "https://example.com/audio.mp3",
url: str = 'https://example.com/audio.mp3',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -448,7 +454,7 @@ def voice_query(
def at_all_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -456,7 +462,7 @@ def at_all_query(
"""Create a group query with @All mention."""
components = [
platform_message.AtAll(),
platform_message.Plain(text=f" {text}"),
platform_message.Plain(text=f' {text}'),
]
chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id)
@@ -469,4 +475,4 @@ def at_all_query(
sender_id=sender_id,
adapter=adapter,
**overrides,
)
)

View File

@@ -33,7 +33,7 @@ class FakePlatform:
def __init__(
self,
*,
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
raise_error: Exception = None,
):
@@ -48,16 +48,16 @@ class FakePlatform:
# Registered listeners
self._listeners: dict = {}
def raises(self, error: Exception) -> "FakePlatform":
def raises(self, error: Exception) -> 'FakePlatform':
"""Configure platform to raise an error on send."""
self._raise_error = error
return self
def send_failure(self) -> "FakePlatform":
def send_failure(self) -> 'FakePlatform':
"""Configure platform to simulate send failure."""
return self.raises(Exception("Platform send failure"))
return self.raises(Exception('Platform send failure'))
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
def supports_streaming(self, supported: bool = True) -> 'FakePlatform':
"""Configure whether streaming output is supported."""
self._stream_output_supported = supported
return self
@@ -89,7 +89,7 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event."""
sender = platform_entities.Friend(
@@ -97,11 +97,13 @@ class FakePlatform:
nickname=nickname,
remark=None,
)
chain = platform_message.MessageChain([
platform_message.Plain(text=text),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -111,9 +113,9 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
mention_bot: bool = False,
) -> platform_events.GroupMessage:
"""Create an inbound group message event.
@@ -142,12 +144,12 @@ class FakePlatform:
components = []
if mention_bot:
components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=' '))
components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -155,8 +157,8 @@ class FakePlatform:
def create_image_message(
self,
url: str = "https://example.com/image.png",
text: str = "",
url: str = 'https://example.com/image.png',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
is_group: bool = False,
group_id: typing.Union[int, str] = 99999,
@@ -169,12 +171,12 @@ class FakePlatform:
chain = platform_message.MessageChain(components)
if is_group:
return self.create_group_message("", sender_id, group_id=group_id)
return self.create_group_message('', sender_id, group_id=group_id)
# Replace chain
else:
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -192,12 +194,14 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "send",
"target_type": target_type,
"target_id": target_id,
"message": message,
})
self._outbound_messages.append(
{
'type': 'send',
'target_type': target_type,
'target_id': target_id,
'message': message,
}
)
async def reply_message(
self,
@@ -209,13 +213,15 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "reply",
"source_type": message_source.type,
"source": message_source,
"message": message,
"quote_origin": quote_origin,
})
self._outbound_messages.append(
{
'type': 'reply',
'source_type': message_source.type,
'source': message_source,
'message': message,
'quote_origin': quote_origin,
}
)
async def reply_message_chunk(
self,
@@ -229,15 +235,17 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_chunks.append({
"type": "reply_chunk",
"source_type": message_source.type,
"source": message_source,
"bot_message": bot_message,
"message": message,
"quote_origin": quote_origin,
"is_final": is_final,
})
self._outbound_chunks.append(
{
'type': 'reply_chunk',
'source_type': message_source.type,
'source': message_source,
'bot_message': bot_message,
'message': message,
'quote_origin': quote_origin,
'is_final': is_final,
}
)
async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported."""
@@ -295,7 +303,7 @@ class FakePlatform:
def fake_platform(
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
) -> FakePlatform:
"""Create a FakePlatform instance."""
@@ -328,9 +336,7 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported)
adapter._fake_platform = platform # Store for assertions
return adapter
return adapter

View File

@@ -27,51 +27,51 @@ class FakeProvider:
Does not require API keys.
"""
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
PONG_RESPONSE = 'LANGBOT_FAKE_PONG'
def __init__(
self,
*,
default_response: str = "fake response",
default_response: str = 'fake response',
streaming_chunks: list[str] = None,
raise_error: Exception = None,
captured_requests: list = None,
):
self._default_response = default_response
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._streaming_chunks = streaming_chunks or ['fake ', 'response']
self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> "FakeProvider":
def returns(self, text: str) -> 'FakeProvider':
"""Configure provider to return a specific text response."""
self._default_response = text
self._streaming_chunks = [text]
return self
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
"""Configure provider to return streaming chunks."""
self._streaming_chunks = chunks
self._default_response = "".join(chunks)
self._default_response = ''.join(chunks)
return self
def raises(self, error: Exception) -> "FakeProvider":
def raises(self, error: Exception) -> 'FakeProvider':
"""Configure provider to raise an error."""
self._raise_error = error
return self
def timeout(self) -> "FakeProvider":
def timeout(self) -> 'FakeProvider':
"""Configure provider to simulate timeout."""
return self.raises(TimeoutError("Provider timeout"))
return self.raises(TimeoutError('Provider timeout'))
def auth_error(self) -> "FakeProvider":
def auth_error(self) -> 'FakeProvider':
"""Configure provider to simulate auth error."""
return self.raises(Exception("Invalid API key"))
return self.raises(Exception('Invalid API key'))
def rate_limit(self) -> "FakeProvider":
def rate_limit(self) -> 'FakeProvider':
"""Configure provider to simulate rate limit."""
return self.raises(Exception("Rate limit exceeded"))
return self.raises(Exception('Rate limit exceeded'))
def malformed(self) -> "FakeProvider":
def malformed(self) -> 'FakeProvider':
"""Configure provider to simulate malformed response."""
self._default_response = None
return self
@@ -87,7 +87,7 @@ class FakeProvider:
def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content."""
return provider_message.Message(
role="assistant",
role='assistant',
content=content,
)
@@ -99,7 +99,7 @@ class FakeProvider:
) -> provider_message.MessageChunk:
"""Create a provider message chunk."""
return provider_message.MessageChunk(
role="assistant",
role='assistant',
content=content,
is_final=is_final,
msg_sequence=msg_sequence,
@@ -116,13 +116,15 @@ class FakeProvider:
) -> provider_message.Message:
"""Simulate non-streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
}
)
# Simulate error if configured
if self._raise_error:
@@ -131,7 +133,7 @@ class FakeProvider:
# Return response
if self._default_response is None:
# Malformed response
return provider_message.Message(role="assistant", content=None)
return provider_message.Message(role='assistant', content=None)
return self._create_message(self._default_response)
@@ -146,14 +148,16 @@ class FakeProvider:
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
"streaming": True,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
'streaming': True,
}
)
# Simulate error if configured
if self._raise_error:
@@ -161,12 +165,12 @@ class FakeProvider:
# Yield chunks
for i, chunk in enumerate(self._streaming_chunks):
is_final = (i == len(self._streaming_chunks) - 1)
is_final = i == len(self._streaming_chunks) - 1
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider(
default_response: str = "fake response",
default_response: str = 'fake response',
) -> FakeProvider:
"""Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response)
@@ -202,8 +206,8 @@ def fake_provider_malformed() -> FakeProvider:
def fake_model(
*,
uuid: str = "test-model-uuid",
name: str = "test-model",
uuid: str = 'test-model-uuid',
name: str = 'test-model',
abilities: list[str] = None,
provider: FakeProvider = None,
) -> Mock:
@@ -212,7 +216,7 @@ def fake_model(
model.model_entity = Mock()
model.model_entity.uuid = uuid
model.model_entity.name = name
model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.abilities = abilities or ['func_call', 'vision']
model.model_entity.extra_args = {}
# Attach fake provider
@@ -221,4 +225,4 @@ def fake_model(
model.provider = provider
return model
return model

View File

@@ -3,4 +3,4 @@ Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""
"""

View File

@@ -2,4 +2,4 @@
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""
"""

View File

@@ -48,6 +48,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield
@@ -56,10 +57,12 @@ def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -71,28 +74,29 @@ def fake_bot_app():
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[
{
app.bot_service.get_bots = AsyncMock(
return_value=[
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
]
)
app.bot_service.get_runtime_bot_info = AsyncMock(
return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
}
])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
)
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1))
app.bot_service.send_message = AsyncMock()
# Platform manager
@@ -118,10 +122,7 @@ class TestBotEndpoints:
@pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'})
assert response.status_code == 200
data = await response.get_json()
@@ -135,7 +136,7 @@ class TestBotEndpoints:
response = await quart_test_client.post(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'},
)
assert response.status_code == 200
@@ -147,8 +148,7 @@ class TestBotEndpoints:
async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -162,7 +162,7 @@ class TestBotEndpoints:
response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}
json={'name': 'Updated Bot'},
)
assert response.status_code == 200
@@ -173,8 +173,7 @@ class TestBotEndpoints:
async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -190,7 +189,7 @@ class TestBotLogsEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}
json={'from_index': -1, 'max_count': 10},
)
assert response.status_code == 200
@@ -213,8 +212,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'person',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 200
@@ -228,7 +227,7 @@ class TestBotSendMessageEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]},
)
assert response.status_code == 400
@@ -244,8 +243,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'invalid',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 400

View File

@@ -47,6 +47,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield
@@ -55,10 +56,12 @@ def fake_embed_app():
"""Create FakeApp with embed widget services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock()
@@ -83,9 +86,7 @@ def fake_embed_app():
# WebSocket proxy bot with adapter
mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}])
mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock()
@@ -117,9 +118,7 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js')
assert response.status_code == 200
assert 'javascript' in response.content_type
@@ -127,18 +126,14 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js')
assert response.status_code == 400
@pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js')
assert response.status_code == 404
@@ -164,8 +159,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 200
@@ -177,8 +171,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 400
@@ -187,8 +180,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={}
)
assert response.status_code == 400
@@ -203,7 +195,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/person returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -216,7 +208,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/group returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -226,7 +218,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages with invalid session_type returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 400
@@ -241,7 +233,7 @@ class TestEmbedResetEndpoint:
"""POST reset/person resets session."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -252,8 +244,7 @@ class TestEmbedResetEndpoint:
async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@@ -269,7 +260,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}
json={'message_id': 'msg-123', 'feedback_type': 1},
)
assert response.status_code == 200
@@ -283,7 +274,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}
json={'message_id': 'msg-123', 'feedback_type': 2},
)
assert response.status_code == 200
@@ -294,7 +285,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}
json={'message_id': 'msg-123', 'feedback_type': 99},
)
assert response.status_code == 400

View File

@@ -49,6 +49,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,33 +75,35 @@ def fake_knowledge_app():
# Knowledge service
app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
{
app.knowledge_service.get_knowledge_bases = AsyncMock(
return_value=[
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
]
)
app.knowledge_service.get_knowledge_base = AsyncMock(
return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
)
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
])
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}]
)
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}])
# RAG manager
app.rag_mgr = Mock()
@@ -124,8 +129,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -140,7 +144,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'},
)
assert response.status_code == 200
@@ -152,8 +156,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -167,7 +170,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}
json={'name': 'Updated KB'},
)
assert response.status_code == 200
@@ -178,8 +181,7 @@ class TestKnowledgeBaseEndpoints:
async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -193,8 +195,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -208,7 +209,7 @@ class TestKnowledgeBaseFilesEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'},
)
assert response.status_code == 200
@@ -220,8 +221,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}},
)
assert response.status_code == 200
@@ -249,9 +249,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={}
)
assert response.status_code == 400

View File

@@ -46,6 +46,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield
@@ -54,10 +55,12 @@ def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock()
@@ -67,40 +70,34 @@ def fake_monitoring_app():
# Monitoring service
app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
})
app.monitoring_service.get_messages = AsyncMock(return_value=(
[{'id': 'msg-1', 'content': 'test'}], 100
))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
[{'id': 'llm-1'}], 50
))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
[{'id': 'emb-1'}], 10
))
app.monitoring_service.get_sessions = AsyncMock(return_value=(
[{'session_id': 'sess-1'}], 20
))
app.monitoring_service.get_errors = AsyncMock(return_value=(
[{'id': 'err-1'}], 2
))
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'found': True,
'session_id': 'sess-1',
})
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_overview_metrics = AsyncMock(
return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
}
)
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10))
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20))
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2))
app.monitoring_service.get_session_analysis = AsyncMock(
return_value={
'found': True,
'session_id': 'sess-1',
}
)
app.monitoring_service.get_message_details = AsyncMock(
return_value={
'found': True,
'message_id': 'msg-1',
}
)
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
@@ -130,8 +127,7 @@ class TestMonitoringOverviewEndpoint:
async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get(
'/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -147,8 +143,7 @@ class TestMonitoringMessagesEndpoint:
async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -165,8 +160,7 @@ class TestMonitoringLLMCallsEndpoint:
async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -180,8 +174,7 @@ class TestMonitoringEmbeddingCallsEndpoint:
async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -195,8 +188,7 @@ class TestMonitoringSessionsEndpoint:
async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -210,8 +202,7 @@ class TestMonitoringErrorsEndpoint:
async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors."""
response = await quart_test_client.get(
'/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -225,8 +216,7 @@ class TestMonitoringAllDataEndpoint:
async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get(
'/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -242,8 +232,7 @@ class TestMonitoringDetailsEndpoints:
async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -252,8 +241,7 @@ class TestMonitoringDetailsEndpoints:
async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -267,8 +255,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -277,8 +264,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -292,8 +278,7 @@ class TestMonitoringExportEndpoint:
async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -303,8 +288,7 @@ class TestMonitoringExportEndpoint:
async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -313,8 +297,7 @@ class TestMonitoringExportEndpoint:
async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -323,8 +306,7 @@ class TestMonitoringExportEndpoint:
async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200

View File

@@ -49,6 +49,7 @@ def mock_circular_import_chain():
):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,27 +75,23 @@ def fake_provider_app():
# Provider service
app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock(return_value=[
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
])
app.provider_service.get_provider = AsyncMock(return_value={
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
})
app.provider_service.get_providers = AsyncMock(
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}]
)
app.provider_service.get_provider = AsyncMock(
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
)
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
})
app.provider_service.get_provider_model_counts = AsyncMock(
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0}
)
# LLM model service
app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}])
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock()
@@ -133,8 +132,7 @@ class TestProviderEndpoints:
async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -157,8 +155,7 @@ class TestProviderEndpoints:
async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -177,7 +174,7 @@ class TestProviderEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}
json={'name': 'New Provider', 'requester': 'chatcmpl'},
)
assert response.status_code == 200
@@ -194,7 +191,7 @@ class TestProviderEndpoints:
response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}
json={'name': 'Updated Provider'},
)
assert response.status_code == 200
@@ -205,8 +202,7 @@ class TestProviderEndpoints:
async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -215,8 +211,7 @@ class TestProviderEndpoints:
async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,8 +232,7 @@ class TestModelEndpoints:
async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -250,8 +244,7 @@ class TestModelEndpoints:
async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -264,7 +257,7 @@ class TestModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -276,8 +269,7 @@ class TestModelEndpoints:
async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -291,8 +283,7 @@ class TestEmbeddingModelEndpoints:
async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -306,7 +297,7 @@ class TestEmbeddingModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -323,8 +314,7 @@ class TestRerankModelEndpoints:
async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -338,7 +328,7 @@ class TestRerankModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200

View File

@@ -20,6 +20,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -69,6 +70,7 @@ def mock_circular_import_chain():
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
@@ -79,12 +81,14 @@ def fake_api_app():
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# API-specific services
app.user_service = Mock()
@@ -118,6 +122,7 @@ def fake_api_app():
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls):
"""
@@ -135,6 +140,7 @@ async def quart_test_client(fake_api_app, http_controller_cls):
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@@ -222,8 +228,7 @@ class TestProtectedEndpoints:
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@@ -254,10 +259,7 @@ class TestInvalidPayload:
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'})
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)

View File

@@ -2,4 +2,4 @@
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""
"""

View File

@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
db_file = tmp_path / 'test_migrations.db'
return f'sqlite+aiosqlite:///{db_file}'
@pytest.fixture
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestSQLiteMigrationFreshDatabase:
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_migrations_fresh.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
assert rev is not None, 'Expected a revision on fresh DB'
await fresh_engine.dispose()
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_empty_migrations.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
# Verify specific behavior - one of two outcomes is expected
if actual_result is not None:
# Migration succeeded - verify revision exists
assert actual_result is not None, "Revision should exist after successful migration"
assert actual_result is not None, 'Revision should exist after successful migration'
else:
# Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables
assert actual_error is not None, "Error should be captured if migration failed"
assert actual_error is not None, 'Error should be captured if migration failed'
# Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios
acceptable_errors = [
'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors
'CommandError', # Alembic command errors
]
assert error_type in acceptable_errors, (
f"Unexpected error type: {error_type}. "
f"This may indicate a regression in migration behavior. "
f"Error: {actual_error}"
f'Unexpected error type: {error_type}. '
f'This may indicate a regression in migration behavior. '
f'Error: {actual_error}'
)
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'

View File

@@ -34,14 +34,14 @@ def postgres_url():
"""Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL')
if not url:
pytest.skip("TEST_POSTGRES_URL not set")
pytest.skip('TEST_POSTGRES_URL not set')
return url
@pytest.fixture
async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT')
yield engine
await engine.dispose()
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -83,9 +83,7 @@ class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp baseline on existing tables sets correct revision.
@@ -106,9 +104,7 @@ class TestPostgreSQLMigrationBaseline:
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp on empty database (no tables) still sets revision.
@@ -125,9 +121,7 @@ class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Upgrade from baseline to head applies all migrations.
@@ -149,14 +143,12 @@ class TestPostgreSQLMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev is not None, "Expected a revision after upgrade"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration (0005 for current state)
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
@pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Running upgrade to head multiple times is idempotent.
@@ -180,7 +172,7 @@ class TestPostgreSQLMigrationUpgrade:
await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestPostgreSQLMigrationGetCurrent:
@@ -199,7 +191,7 @@ class TestPostgreSQLMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(postgres_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision(
@@ -214,4 +206,4 @@ class TestPostgreSQLMigrationGetCurrent:
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'

View File

@@ -2,4 +2,4 @@
Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner.
"""
"""

View File

@@ -26,6 +26,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -103,6 +104,7 @@ def mock_circular_import_chain():
# ============== FAKE RUNNER ==============
class FakeRunner:
"""Minimal fake runner class for pipeline integration tests.
@@ -117,12 +119,13 @@ class FakeRunner:
self.config = config or {}
self._provider = FakeProvider()
# Instance-level configuration set via class attribute
self._response_text = "fake response"
self._response_text = 'fake response'
self._raise_error = None
@classmethod
def returns(cls, text: str):
"""Create a runner class configured to return specific text."""
# We create a subclass with configured response
class ConfiguredRunner(cls):
name = cls.name
@@ -132,11 +135,13 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._response_text = text
return ConfiguredRunner
@classmethod
def raises(cls, error: Exception):
"""Create a runner class configured to raise an error."""
class ConfiguredRunner(cls):
name = cls.name
_response_text = None
@@ -145,6 +150,7 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._raise_error = error
return ConfiguredRunner
async def run(self, query):
@@ -161,6 +167,7 @@ class FakeRunner:
# ============== PIPELINE APP FIXTURE ==============
@pytest.fixture
def pipeline_app():
"""
@@ -187,6 +194,7 @@ def pipeline_app():
def __init__(self, name, messages):
self.name = name
self.messages = messages
def copy(self):
return MockPrompt(self.name, list(self.messages))
@@ -237,14 +245,17 @@ def fake_platform_adapter():
@pytest.fixture
def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner
# ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests."""
return {
@@ -273,6 +284,7 @@ def create_minimal_pipeline_config():
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name):
"""
Helper to handle the coroutine -> async_generator pattern.
@@ -296,6 +308,7 @@ async def collect_processor_results(processor, query, stage_name):
# ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain."""
@@ -337,7 +350,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
# Create query with adapter
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -365,7 +378,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
query = text_query("test message content")
query = text_query('test message content')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -396,11 +409,11 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Set fake runner that returns pong
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
@@ -414,6 +427,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -432,7 +446,7 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -445,6 +459,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -462,13 +477,13 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Create reply chain
reply_chain = text_chain("plugin response")
reply_chain = text_chain('plugin response')
# Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock()
@@ -479,6 +494,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -502,7 +518,7 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded'))
set_fake_runner(fake_runner)
# Create query with exception handling config
@@ -510,7 +526,7 @@ class TestRunnerExceptionFlow:
config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -523,6 +539,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -541,14 +558,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error'))
set_fake_runner(fake_runner)
# Create query with show-error mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -561,6 +578,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -578,14 +596,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception("Hidden error"))
fake_runner = FakeRunner().raises(Exception('Hidden error'))
set_fake_runner(fake_runner)
# Create query with hide mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -598,6 +616,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -623,7 +642,7 @@ class TestSendResponseBackStage:
adapter, platform = fake_platform_adapter
# Create query with response message
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -666,12 +685,12 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Set fake runner
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
config = create_minimal_pipeline_config()
query = text_query("ping")
query = text_query('ping')
query.adapter = adapter
query.pipeline_config = config
query.resp_messages = []
@@ -690,7 +709,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -711,6 +730,7 @@ class TestStageChainIntegration:
# Build resp_message_chain from resp_messages
from tests.factories.message import text_chain
for resp_msg in query.resp_messages:
if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content))
@@ -737,7 +757,7 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -754,7 +774,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -775,4 +795,4 @@ class TestStageChainIntegration:
assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages
assert len(query.resp_messages) == 0
assert len(query.resp_messages) == 0

View File

@@ -3,4 +3,4 @@ Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q
"""
"""

View File

@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
assert app.instance_config is not None
# Verify default config
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data["command"]["enable"] is True
assert app.instance_config.data['command']['prefix'] == ['/', '!']
assert app.instance_config.data['command']['enable'] is True
@pytest.mark.asyncio
async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response."""
provider = FakeProvider(default_response="test response")
provider = FakeProvider(default_response='test response')
# Create mock model with provider
model = fake_model(provider=provider)
# Create a simple query
query = text_query("hello")
query = text_query('hello')
# Simulate invoke
result = await provider.invoke_llm(
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
)
assert result is not None
assert result.role == "assistant"
assert result.content == "test response"
assert result.role == 'assistant'
assert result.content == 'test response'
@pytest.mark.asyncio
async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong()
model = fake_model(provider=provider)
query = text_query("ping")
query = text_query('ping')
result = await provider.invoke_llm(
query=query,
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
@pytest.mark.asyncio
async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(["Hello", " World"])
provider = FakeProvider().returns_streaming(['Hello', ' World'])
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
chunks = []
# invoke_llm_stream returns an async generator, don't await it
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].content == "Hello"
assert chunks[1].content == " World"
assert chunks[0].content == 'Hello'
assert chunks[1].content == ' World'
assert chunks[1].is_final is True
@pytest.mark.asyncio
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(TimeoutError, match="Provider timeout"):
with pytest.raises(TimeoutError, match='Provider timeout'):
await provider.invoke_llm(
query=query,
model=model,
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(Exception, match="Rate limit exceeded"):
with pytest.raises(Exception, match='Rate limit exceeded'):
await provider.invoke_llm(
query=query,
model=model,
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments."""
provider = FakeProvider()
model = fake_model(name="gpt-4", provider=provider)
query = text_query("hello")
model = fake_model(name='gpt-4', provider=provider)
query = text_query('hello')
await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "hello"}],
funcs=[{"name": "test_func"}],
extra_args={"temperature": 0.7},
messages=[{'role': 'user', 'content': 'hello'}],
funcs=[{'name': 'test_func'}],
extra_args={'temperature': 0.7},
)
captured = provider.get_captured_requests()
assert len(captured) == 1
assert captured[0]["model"] == "gpt-4"
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]["extra_args"] == {"temperature": 0.7}
assert captured[0]['model'] == 'gpt-4'
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}]
assert captured[0]['funcs'] == [{'name': 'test_func'}]
assert captured[0]['extra_args'] == {'temperature': 0.7}
@pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id="test-bot")
query = text_query("hello")
platform = FakePlatform(bot_account_id='test-bot')
query = text_query('hello')
# Simulate sending reply
from tests.factories.message import text_chain
reply_chain = text_chain("response text")
reply_chain = text_chain('response text')
event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False)
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
# Verify captured
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert outbound[0]["message"] == reply_chain
assert outbound[0]['type'] == 'reply'
assert outbound[0]['message'] == reply_chain
@pytest.mark.asyncio
async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_friend_message(
text="hello bot",
text='hello bot',
sender_id=12345,
nickname="TestUser",
nickname='TestUser',
)
assert event.type == "FriendMessage"
assert event.type == 'FriendMessage'
assert event.sender.id == 12345
assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == "hello bot"
assert event.sender.nickname == 'TestUser'
assert str(event.message_chain) == 'hello bot'
@pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_group_message(
text="hello everyone",
text='hello everyone',
sender_id=12345,
group_id=99999,
mention_bot=True,
)
assert event.type == "GroupMessage"
assert event.type == 'GroupMessage'
assert event.sender.id == 12345
assert event.group.id == 99999
@@ -220,54 +220,57 @@ class TestFakeMessageFlow:
async def test_query_factories_basic(self):
"""Test basic query factory functions."""
# Text query
q1 = text_query("hello world")
assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == "hello world"
q1 = text_query('hello world')
assert q1.launcher_type.value == 'person'
assert str(q1.message_chain) == 'hello world'
# Group query
from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
assert q2.launcher_type.value == "group"
q2 = group_text_query('hello group', group_id=88888)
assert q2.launcher_type.value == 'group'
assert q2.launcher_id == 88888
# Command query
from tests.factories import command_query
q3 = command_query("help", prefix="/")
assert str(q3.message_chain) == "/help"
q3 = command_query('help', prefix='/')
assert str(q3.message_chain) == '/help'
# Mention query
from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
assert q4.launcher_type.value == "group"
q4 = mention_query('hi', target='test-bot', group_id=77777)
assert q4.launcher_type.value == 'group'
@pytest.mark.asyncio
async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure()
query = text_query("hello")
query = text_query('hello')
from tests.factories.message import text_chain
with pytest.raises(Exception, match="Platform send failure"):
with pytest.raises(Exception, match='Platform send failure'):
await platform.reply_message(
query.message_event,
text_chain("response"),
text_chain('response'),
)
@pytest.mark.asyncio
async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id="bot-123")
platform = FakePlatform(bot_account_id='bot-123')
adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == "bot-123"
assert adapter.bot_account_id == 'bot-123'
assert adapter._fake_platform is platform
# Test reply_message is wired
from tests.factories.message import text_chain
query = text_query("test")
await adapter.reply_message(query.message_event, text_chain("response"))
query = text_query('test')
await adapter.reply_message(query.message_event, text_chain('response'))
# Verify platform captured it
assert len(platform.get_outbound_messages()) == 1
@@ -293,18 +296,18 @@ class TestMessageFlowIntegration:
Note: This does NOT run actual LangBot pipeline stages.
"""
# Setup
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
provider = fake_provider_pong()
model = fake_model(provider=provider)
# Create inbound message
query = text_query("ping")
query = text_query('ping')
# Simulate provider processing
response = await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "ping"}],
messages=[{'role': 'user', 'content': 'ping'}],
funcs=[],
extra_args={},
)
@@ -321,16 +324,16 @@ class TestMessageFlowIntegration:
# Verify platform captured outbound
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
assert outbound[0]['type'] == 'reply'
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(["Hello", " there"])
provider = FakeProvider().returns_streaming(['Hello', ' there'])
model = fake_model(provider=provider)
query = text_query("hi")
query = text_query('hi')
chunks = []
async for chunk in provider.invoke_llm_stream(
@@ -344,8 +347,8 @@ class TestMessageFlowIntegration:
# Verify streaming worked
assert len(chunks) == 2
full_content = "".join(c.content for c in chunks)
assert full_content == "Hello there"
full_content = ''.join(c.content for c in chunks)
assert full_content == 'Hello there'
# Verify platform supports streaming
assert await platform.is_stream_output_supported() is True
assert await platform.is_stream_output_supported() is True

View File

@@ -15,22 +15,12 @@ import pathlib
# Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py'
def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
tree = ast.parse(source)
@@ -40,27 +30,26 @@ def test_no_exec_call_in_system_controller():
if isinstance(node, ast.Call):
func = node.func
# Match bare exec() call
if isinstance(func, ast.Name) and func.id == "exec":
if isinstance(func, ast.Name) and func.id == 'exec':
exec_calls.append(node.lineno)
assert len(exec_calls) == 0, (
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().'
)
def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
assert "debug/exec" not in source, (
"The /debug/exec route still exists in system.py. "
"This endpoint allows arbitrary code execution and must be removed."
assert 'debug/exec' not in source, (
'The /debug/exec route still exists in system.py. '
'This endpoint allows arbitrary code execution and must be removed.'
)
if __name__ == "__main__":
if __name__ == '__main__':
test_no_exec_call_in_system_controller()
test_no_debug_exec_route()
print("All tests passed!")
print('All tests passed!')

View File

@@ -1 +1 @@
"""Unit tests for LangBot API HTTP service layer."""
"""Unit tests for LangBot API HTTP service layer."""

View File

@@ -13,4 +13,4 @@ Does NOT:
- Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application.
"""
"""

View File

@@ -132,9 +132,7 @@ class TestApiKeyServiceCreateApiKey:
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}]
assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key'

View File

@@ -303,13 +303,7 @@ class TestBotServiceCreateBot:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
@@ -318,9 +312,7 @@ class TestBotServiceCreateBot:
bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'})
service = BotService(ap)
@@ -352,6 +344,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -362,9 +355,7 @@ class TestBotServiceCreateBot:
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'})
service = BotService(ap)
@@ -397,6 +388,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -492,6 +484,7 @@ class TestBotServiceUpdateBot:
pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -582,10 +575,9 @@ class TestBotServiceListEventLogs:
# Mock runtime bot with logger
runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock(return_value=(
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
5
))
runtime_bot.logger.get_logs = AsyncMock(
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5)
)
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
@@ -646,11 +638,7 @@ class TestBotServiceSendMessage:
service = BotService(ap)
# Execute with valid message chain format
message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]}
# Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:

View File

@@ -6,6 +6,7 @@ Tests cover:
- Knowledge engine discovery
- File operations
"""
from __future__ import annotations
import pytest
@@ -52,9 +53,7 @@ class TestGetKnowledgeBases:
"""Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}])
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
@@ -83,9 +82,7 @@ class TestGetKnowledgeBase:
"""Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'})
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1')
@@ -153,9 +150,7 @@ class TestCreateKnowledgeBase:
service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'})
# Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
@@ -170,20 +165,21 @@ class TestUpdateKnowledgeBase:
"""Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'})
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields
await service.update_knowledge_base('kb1', {
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
})
await service.update_knowledge_base(
'kb1',
{
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
},
)
# Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args
@@ -288,9 +284,7 @@ class TestListKnowledgeEngines:
"""Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
@@ -386,12 +380,10 @@ class TestGetEngineSchemas:
"""Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert result == {}
mock_app.logger.warning.assert_called_once()
mock_app.logger.warning.assert_called_once()

View File

@@ -174,9 +174,7 @@ class TestMaintenanceServiceGetStorageAnalysis:
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
@@ -292,12 +290,8 @@ class TestMaintenanceServiceGetStorageAnalysis:
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[
{'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}])
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}])
# Execute
result = await service.get_storage_analysis()
@@ -367,6 +361,7 @@ class TestMaintenanceServiceBinaryStorageStats:
size_result = _create_mock_result(scalar_value=5000)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +391,7 @@ class TestMaintenanceServiceBinaryStorageStats:
count_result = _create_mock_result(scalar_value=5)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -821,4 +817,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included
assert 'path' in result[0]
assert 'path' in result[0]

View File

@@ -186,13 +186,7 @@ class TestMCPServiceCreateMCPServer:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}}
ap.plugin_connector = SimpleNamespace()
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
@@ -252,6 +246,7 @@ class TestMCPServiceCreateMCPServer:
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -361,6 +356,7 @@ class TestMCPServiceUpdateMCPServer:
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -394,6 +390,7 @@ class TestMCPServiceUpdateMCPServer:
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -432,6 +429,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock for: first select -> update -> second select (for updated server)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -465,6 +463,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock execute for select and update
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -499,6 +498,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Server to Delete')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -530,6 +530,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Not in Sessions')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -559,6 +560,7 @@ class TestMCPServiceDeleteMCPServer:
# No server found
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -596,9 +598,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=123)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123))
service = MCPService(ap)
@@ -634,9 +634,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=456)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456))
service = MCPService(ap)
@@ -645,4 +643,4 @@ class TestMCPServiceTestMCPServer:
# Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456
assert task_id == 456

View File

@@ -167,6 +167,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([])
call_count = 0
async def mock_execute(query):
return mock_result if call_count == 0 else mock_provider_result
@@ -200,6 +201,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -239,6 +241,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -279,6 +282,7 @@ class TestLLMModelsServiceGetLLMModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -337,9 +341,7 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'model-1', 'name': 'Model 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'})
service = LLMModelsService(ap)
@@ -371,12 +373,14 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
})
model_uuid = await service.create_llm_model(
{
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -400,13 +404,16 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}, preserve_uuid=True)
model_uuid = await service.create_llm_model(
{
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
},
preserve_uuid=True,
)
# Verify
assert model_uuid == 'preserved-uuid'
@@ -459,12 +466,14 @@ class TestLLMModelsServiceCreateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_llm_model({
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
})
await service.create_llm_model(
{
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
}
)
async def test_create_llm_model_with_provider_data(self):
"""Creates provider when provider data provided."""
@@ -490,16 +499,18 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute - with provider data (no UUID)
result_uuid = await service.create_llm_model({
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
})
result_uuid = await service.create_llm_model(
{
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
}
)
# Verify - provider_service was called and UUID generated
ap.provider_service.find_or_create_provider.assert_called_once()
@@ -525,11 +536,14 @@ class TestLLMModelsServiceUpdateLLMModel:
service = LLMModelsService(ap)
# Execute
await service.update_llm_model('existing-uuid', {
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
})
await service.update_llm_model(
'existing-uuid',
{
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
},
)
# Verify - remove and load called
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
@@ -549,10 +563,13 @@ class TestLLMModelsServiceUpdateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.update_llm_model('model-uuid', {
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
})
await service.update_llm_model(
'model-uuid',
{
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
},
)
async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args."""
@@ -618,9 +635,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'})
service = EmbeddingModelsService(ap)
@@ -643,6 +658,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -683,6 +699,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -742,11 +759,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
service = EmbeddingModelsService(ap)
# Execute
model_uuid = await service.create_embedding_model({
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_embedding_model(
{
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -767,11 +786,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_embedding_model({
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_embedding_model(
{
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
@@ -829,6 +850,7 @@ class TestRerankModelsServiceGetRerankModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -869,6 +891,7 @@ class TestRerankModelsServiceGetRerankModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -928,11 +951,13 @@ class TestRerankModelsServiceCreateRerankModel:
service = RerankModelsService(ap)
# Execute
model_uuid = await service.create_rerank_model({
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_rerank_model(
{
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -952,11 +977,13 @@ class TestRerankModelsServiceCreateRerankModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_rerank_model({
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_rerank_model(
{
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestRerankModelsServiceDeleteRerankModel:
@@ -995,9 +1022,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'})
service = EmbeddingModelsService(ap)
@@ -1022,9 +1047,7 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'})
service = RerankModelsService(ap)
@@ -1042,14 +1065,10 @@ class TestValidateProviderSupports:
def _make_ap(requester_name: str, support_type):
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
manifest = SimpleNamespace(spec={'support_type': support_type})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester=requester_name)
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest
if name == requester_name
else None,
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None,
)
return SimpleNamespace(model_mgr=model_mgr)
@@ -1066,9 +1085,7 @@ class TestValidateProviderSupports:
async def test_allows_when_support_type_missing(self):
# Manifest without support_type must not block (backward compatible)
manifest = SimpleNamespace(spec={})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester='legacy')
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy'))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest,

View File

@@ -215,13 +215,7 @@ class TestPipelineServiceCreatePipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -229,9 +223,7 @@ class TestPipelineServiceCreatePipeline:
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'})
service = PipelineService(ap)
@@ -258,14 +250,14 @@ class TestPipelineServiceCreatePipeline:
# Mock persistence for insert
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
# Mock the file read for default config - patch at the utils module level
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
# Verify
@@ -286,7 +278,9 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
service.get_pipeline = AsyncMock(
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
)
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
@@ -296,7 +290,9 @@ class TestPipelineServiceCreatePipeline:
# Mock the file read
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
# Verify - execute was called
@@ -316,10 +312,12 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
})
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
}
)
insert_params = []
@@ -339,7 +337,9 @@ class TestPipelineServiceCreatePipeline:
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'New Pipeline'})
assert len(insert_params) == 1
@@ -353,6 +353,7 @@ class TestPipelineServiceCreatePipeline:
class _MockResultWithBots:
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
def __init__(self, bots_list):
self._bots_list = bots_list
@@ -428,6 +429,7 @@ class TestPipelineServiceUpdatePipeline:
# 1. UPDATE (line 125) - returns Mock (no result needed)
# 2. SELECT bots (line 136) - returns bot_result with .all()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -528,13 +530,7 @@ class TestPipelineServiceCopyPipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -542,10 +538,12 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
# Mock get_pipelines to return 2 pipelines
service.get_pipelines = AsyncMock(return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
])
service.get_pipelines = AsyncMock(
return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
]
)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'):
@@ -642,9 +640,7 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'copy-uuid', 'is_default': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False})
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
@@ -681,11 +677,10 @@ class TestPipelineServiceUpdatePipelineExtensions:
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
)
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []})
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -700,7 +695,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -711,7 +706,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -738,6 +733,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
original_pipeline = _create_mock_pipeline()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -752,7 +748,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_mcp_servers': False,
'mcp_servers': ['mcp-server-1'],
}
},
}
)
@@ -794,6 +790,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1

View File

@@ -245,12 +245,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
provider_uuid = await service.create_provider({
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
provider_uuid = await service.create_provider(
{
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - UUID is generated
assert provider_uuid is not None
@@ -274,12 +276,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
result_uuid = await service.create_provider({
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
result_uuid = await service.create_provider(
{
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - provider added to runtime dict and UUID generated
ap.model_mgr.load_provider.assert_called_once()
@@ -302,10 +306,13 @@ class TestModelProviderServiceUpdateProvider:
service = ModelProviderService(ap)
# Execute
await service.update_provider('existing-uuid', {
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
})
await service.update_provider(
'existing-uuid',
{
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
},
)
# Verify - reload called
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
@@ -364,6 +371,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=None)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +404,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -454,6 +463,7 @@ class TestModelProviderServiceGetProviderModelCounts:
rerank_result.scalar = Mock(return_value=1)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -637,9 +647,7 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
await service.update_space_model_provider_api_keys('space-api-key')
# Verify - update and reload called for Space provider UUID
ap.model_mgr.reload_provider.assert_called_once_with(
'00000000-0000-0000-0000-000000000000'
)
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000')
class TestModelProviderServiceScanProviderModels:
@@ -795,9 +803,7 @@ class TestModelProviderServiceScanProviderModels:
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
runtime_provider.requester.scan_models = AsyncMock(
side_effect=NotImplementedError('scan not supported')
)
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported'))
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
service = ModelProviderService(ap)
@@ -848,9 +854,7 @@ class TestModelProviderServiceScanProviderModels:
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing LLM model
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
return_value=[{'name': 'Existing Model'}]
)
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}])
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
@@ -863,4 +867,4 @@ class TestModelProviderServiceScanProviderModels:
assert existing_model['already_added'] is True
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
assert new_model['already_added'] is False
assert new_model['already_added'] is False

View File

@@ -393,14 +393,16 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -429,10 +431,12 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 1,
'msg': 'Invalid refresh token',
})
mock_response.json = AsyncMock(
return_value={
'code': 1,
'msg': 'Invalid refresh token',
}
)
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
@@ -489,14 +493,16 @@ class TestSpaceServiceExchangeOAuthCode:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -555,13 +561,15 @@ class TestSpaceServiceGetUserInfoRaw:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -669,27 +677,29 @@ class TestSpaceServiceGetModels:
# Mock HTTP response with proper model data matching SpaceModel schema
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -775,4 +785,4 @@ class TestSpaceServiceCreditsCache:
# Verify - cache updated
assert result == 500
assert 'test@example.com' in service._credits_cache
assert service._credits_cache['test@example.com'][0] == 500
assert service._credits_cache['test@example.com'][0] == 500

View File

@@ -495,6 +495,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -565,6 +566,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -605,4 +607,4 @@ class TestUserServiceCreateUserLock:
# Verify lock exists
assert hasattr(service, '_create_user_lock')
assert service._create_user_lock is not None
assert service._create_user_lock is not None

View File

@@ -132,6 +132,7 @@ class TestWebhookServiceCreateWebhook:
# execute_async returns different results
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -181,6 +182,7 @@ class TestWebhookServiceCreateWebhook:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -217,6 +219,7 @@ class TestWebhookServiceCreateWebhook:
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -225,9 +228,7 @@ class TestWebhookServiceCreateWebhook:
return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'id': 1, 'enabled': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False})
service = WebhookService(ap)
@@ -503,4 +504,4 @@ class TestWebhookServiceGetEnabledWebhooks:
result = await service.get_enabled_webhooks()
# Verify - should be empty (SQL would filter disabled)
assert result == []
assert result == []

View File

@@ -407,7 +407,9 @@ def test_box_service_forced_template_ignores_pipeline_config():
launcher_type='person',
launcher_id='test_user',
sender_id='test_user',
pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
pipeline_config={
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
},
)
assert service.resolve_box_session_id(query) == 'global'
@@ -1527,9 +1529,7 @@ class TestBuildSkillExtraMounts:
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
]
# No skill is dropped, so no "missing" warning should be logged.
assert not any(
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
)
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list)
def test_skips_skill_with_empty_package_root(self):
logger = Mock()

View File

@@ -1 +1 @@
# Unit tests for command module
# Unit tests for command module

View File

@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
# Should yield CommandNotFoundError (no such command registered)
assert len(results) == 1
assert results[0].error is not None
assert results[0].error is not None

View File

@@ -197,6 +197,7 @@ class TestCommandOperatorBase:
op = TestOperator(None)
# Should not raise
import asyncio
asyncio.get_event_loop().run_until_complete(op.initialize())
def test_execute_is_abstract(self):
@@ -299,4 +300,4 @@ class TestMultipleOperators:
yield None
assert AdminOperator.lowest_privilege == 2
assert SubOperator.lowest_privilege == 1
assert SubOperator.lowest_privilege == 1

View File

@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
@pytest.mark.asyncio
async def test_valid_yaml_loads(self, tmp_path):
"""Valid YAML config should load correctly."""
config_file = tmp_path / "test_config.yaml"
config_file = tmp_path / 'test_config.yaml'
# Write valid YAML
config_file.write_text("""
@@ -51,7 +51,7 @@ settings:
@pytest.mark.asyncio
async def test_invalid_yaml_raises_error(self, tmp_path):
"""Invalid YAML should raise clear error."""
config_file = tmp_path / "invalid.yaml"
config_file = tmp_path / 'invalid.yaml'
# Write invalid YAML (unclosed bracket)
config_file.write_text("""
@@ -67,13 +67,13 @@ settings:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await yaml_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_config_creates_from_template(self, tmp_path):
"""Missing config file should be created from template."""
config_file = tmp_path / "new_config.yaml"
config_file = tmp_path / 'new_config.yaml'
# File doesn't exist yet
assert not config_file.exists()
@@ -92,7 +92,7 @@ settings:
@pytest.mark.asyncio
async def test_template_completion(self, tmp_path):
"""Config should be completed with template defaults."""
config_file = tmp_path / "partial.yaml"
config_file = tmp_path / 'partial.yaml'
# Write partial config missing some template keys
config_file.write_text("""
@@ -115,7 +115,7 @@ name: custom_name
@pytest.mark.asyncio
async def test_yaml_save(self, tmp_path):
"""YAML config can be saved."""
config_file = tmp_path / "save_test.yaml"
config_file = tmp_path / 'save_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -131,7 +131,7 @@ name: custom_name
def test_yaml_save_sync(self, tmp_path):
"""YAML config can be saved synchronously."""
config_file = tmp_path / "sync_save.yaml"
config_file = tmp_path / 'sync_save.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -151,14 +151,18 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_valid_json_loads(self, tmp_path):
"""Valid JSON config should load correctly."""
config_file = tmp_path / "test_config.json"
config_file = tmp_path / 'test_config.json'
# Write valid JSON
config_file.write_text(json.dumps({
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}))
config_file.write_text(
json.dumps(
{
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}
)
)
json_file = JSONConfigFile(
str(config_file),
@@ -174,7 +178,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_invalid_json_raises_error(self, tmp_path):
"""Invalid JSON should raise clear error."""
config_file = tmp_path / "invalid.json"
config_file = tmp_path / 'invalid.json'
# Write invalid JSON (missing closing brace)
config_file.write_text('{"name": "test", "unclosed": ')
@@ -184,13 +188,13 @@ class TestJSONConfigFile:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await json_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_json_creates_from_template(self, tmp_path):
"""Missing JSON file should be created from template."""
config_file = tmp_path / "new_config.json"
config_file = tmp_path / 'new_config.json'
json_file = JSONConfigFile(
str(config_file),
@@ -205,7 +209,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_json_save(self, tmp_path):
"""JSON config can be saved."""
config_file = tmp_path / "save_test.json"
config_file = tmp_path / 'save_test.json'
json_file = JSONConfigFile(
str(config_file),
@@ -226,7 +230,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_load(self, tmp_path):
"""ConfigManager loads config correctly."""
config_file = tmp_path / "manager_test.yaml"
config_file = tmp_path / 'manager_test.yaml'
config_file.write_text('name: managed_app\nversion: "1.0"\n')
yaml_file = YAMLConfigFile(
@@ -243,7 +247,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_dump(self, tmp_path):
"""ConfigManager can dump config."""
config_file = tmp_path / "dump_test.yaml"
config_file = tmp_path / 'dump_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -260,7 +264,7 @@ class TestConfigManager:
def test_config_manager_dump_sync(self, tmp_path):
"""ConfigManager can dump config synchronously."""
config_file = tmp_path / "sync_dump.yaml"
config_file = tmp_path / 'sync_dump.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -280,7 +284,7 @@ class TestConfigExists:
def test_yaml_exists_true(self, tmp_path):
"""exists() returns True for existing file."""
config_file = tmp_path / "exists.yaml"
config_file = tmp_path / 'exists.yaml'
config_file.write_text('name: test')
yaml_file = YAMLConfigFile(str(config_file), template_data={})
@@ -288,14 +292,14 @@ class TestConfigExists:
def test_yaml_exists_false(self, tmp_path):
"""exists() returns False for missing file."""
config_file = tmp_path / "missing.yaml"
config_file = tmp_path / 'missing.yaml'
yaml_file = YAMLConfigFile(str(config_file), template_data={})
assert yaml_file.exists() is False
def test_json_exists_true(self, tmp_path):
"""exists() returns True for existing JSON file."""
config_file = tmp_path / "exists.json"
config_file = tmp_path / 'exists.json'
config_file.write_text('{}')
json_file = JSONConfigFile(str(config_file), template_data={})
@@ -303,7 +307,7 @@ class TestConfigExists:
def test_json_exists_false(self, tmp_path):
"""exists() returns False for missing JSON file."""
config_file = tmp_path / "missing.json"
config_file = tmp_path / 'missing.json'
json_file = JSONConfigFile(str(config_file), template_data={})
assert json_file.exists() is False
assert json_file.exists() is False

View File

@@ -1 +1 @@
"""Core module unit tests."""
"""Core module unit tests."""

View File

@@ -4,6 +4,7 @@ Tests cover:
- _get_positive_int_config() validation
- _get_positive_float_config() validation
"""
from __future__ import annotations
from unittest.mock import Mock
@@ -188,4 +189,4 @@ class TestGetPositiveFloatConfig:
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
assert result == 1.5
mock_logger.warning.assert_called_once()
mock_logger.warning.assert_called_once()

View File

@@ -27,6 +27,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert result == []
@@ -46,6 +47,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert 'requests' in result
@@ -61,6 +63,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
# Should include all required_deps keys
@@ -107,6 +110,7 @@ class TestPrecheckPluginDeps:
with patch('os.path.exists', return_value=False):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_not_called()
@@ -129,6 +133,7 @@ class TestPrecheckPluginDeps:
with patch('os.listdir', side_effect=mock_listdir):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])

View File

@@ -7,6 +7,7 @@ Tests cover:
- Dict type skipping
- Missing key creation
"""
from __future__ import annotations
import os
@@ -248,15 +249,8 @@ class TestApplyEnvOverridesToConfig:
"""Test multiple env vars applied in order."""
load_config = get_load_config_module()
cfg = {
'system': {'name': 'default', 'enable': True},
'concurrency': {'pipeline': 5}
}
env = {
'SYSTEM__NAME': 'custom',
'SYSTEM__ENABLE': 'false',
'CONCURRENCY__PIPELINE': '10'
}
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}}
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'}
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
@@ -287,4 +281,4 @@ class TestApplyEnvOverridesToConfig:
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'

View File

@@ -175,4 +175,4 @@ class TestPreregisteredStages:
pass
for key in preregistered_stages:
assert isinstance(key, str)
assert isinstance(key, str)

View File

@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import_isolation to break circular import chains.
"""
from __future__ import annotations
import pytest
@@ -19,15 +20,17 @@ from typing import Generator
class MockLifecycleControlScopeEnum:
"""Mock enum value for LifecycleControlScope with .value attribute."""
def __init__(self, value: str):
self.value = value
def __repr__(self):
return f"LifecycleControlScope.{self.value.upper()}"
return f'LifecycleControlScope.{self.value.upper()}'
class MockLifecycleControlScope:
"""Mock enum for LifecycleControlScope."""
APPLICATION = MockLifecycleControlScopeEnum('application')
PLATFORM = MockLifecycleControlScopeEnum('platform')
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
@@ -40,17 +43,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
# Mock modules that cause circular imports
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
mock_app = MagicMock()
mock_importutil = MagicMock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
mock_http_controller = MagicMock()
mock_rag_mgr = MagicMock()
mocks = {
'langbot.pkg.core.entities': mock_entities,
'langbot.pkg.core.app': mock_app,
@@ -58,26 +61,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
'langbot.pkg.utils.importutil': mock_importutil,
}
# Save original state
saved = {}
for name in mocks:
if name in sys.modules:
saved[name] = sys.modules[name]
# Clear taskmgr to force re-import
taskmgr_name = 'langbot.pkg.core.taskmgr'
if taskmgr_name in sys.modules:
saved[taskmgr_name] = sys.modules[taskmgr_name]
try:
# Apply mocks
for name, module in mocks.items():
sys.modules[name] = module
# Clear taskmgr
sys.modules.pop(taskmgr_name, None)
yield
finally:
# Restore
@@ -86,7 +89,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
sys.modules[name] = saved[name]
else:
sys.modules.pop(name, None)
if taskmgr_name in saved:
sys.modules[taskmgr_name] = saved[taskmgr_name]
else:
@@ -97,6 +100,7 @@ def get_taskmgr_classes():
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
return TaskContext, TaskWrapper, AsyncTaskManager
@@ -194,9 +198,10 @@ class TestTaskContext:
"""Test TaskContext.placeholder() returns singleton."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext
# Reset global placeholder
import langbot.pkg.core.taskmgr as taskmgr_module
taskmgr_module.placeholder_context = None
ctx1 = TaskContext.placeholder()
@@ -269,7 +274,8 @@ class TestTaskWrapper:
return 'result'
wrapper = TaskWrapper(
mock_app, immediate_coro(),
mock_app,
immediate_coro(),
name='test_task',
label='Test Task',
)
@@ -414,7 +420,7 @@ class TestAsyncTaskManager:
async def test_cancel_by_scope(self):
"""Test cancel_by_scope cancels matching tasks."""
_, _, AsyncTaskManager = get_taskmgr_classes()
mock_app = create_mock_app()
manager = AsyncTaskManager(mock_app)
@@ -422,16 +428,10 @@ class TestAsyncTaskManager:
await asyncio.sleep(10)
# Create task with APPLICATION scope
w1 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.APPLICATION]
)
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION])
# Create task with different scope
w2 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.PIPELINE]
)
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE])
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)

View File

@@ -15,68 +15,68 @@ class TestI18nString:
def test_create_with_english_only(self):
"""Create I18nString with only English."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
assert i18n.en_US == "Hello"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans is None
def test_create_with_multiple_languages(self):
"""Create I18nString with multiple languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
)
assert i18n.en_US == "Hello"
assert i18n.zh_Hans == "你好"
assert i18n.zh_Hant == "你好"
assert i18n.ja_JP == "こんにちは"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans == '你好'
assert i18n.zh_Hant == '你好'
assert i18n.ja_JP == 'こんにちは'
def test_to_dict_with_english_only(self):
"""to_dict returns only non-None fields."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
result = i18n.to_dict()
assert result == {"en_US": "Hello"}
assert result == {'en_US': 'Hello'}
def test_to_dict_with_multiple_languages(self):
"""to_dict returns all non-None fields."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
en_US='Hello',
zh_Hans='你好',
)
result = i18n.to_dict()
assert result == {"en_US": "Hello", "zh_Hans": "你好"}
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'}
def test_to_dict_excludes_none(self):
"""to_dict excludes None values."""
i18n = I18nString(
en_US="Hello",
en_US='Hello',
zh_Hans=None,
ja_JP="こんにちは",
ja_JP='こんにちは',
)
result = i18n.to_dict()
assert "zh_Hans" not in result
assert "en_US" in result
assert "ja_JP" in result
assert 'zh_Hans' not in result
assert 'en_US' in result
assert 'ja_JP' in result
def test_to_dict_all_languages(self):
"""to_dict with all supported languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
th_TH="สวัสดี",
vi_VN="Xin chào",
es_ES="Hola",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
th_TH='สวัสดี',
vi_VN='Xin chào',
es_ES='Hola',
)
result = i18n.to_dict()
@@ -92,30 +92,30 @@ class TestMetadata:
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test Component"),
name='test-component',
label=I18nString(en_US='Test Component'),
)
assert metadata.name == "test-component"
assert metadata.label.en_US == "Test Component"
assert metadata.name == 'test-component'
assert metadata.label.en_US == 'Test Component'
def test_create_with_all_fields(self):
"""Create Metadata with all optional fields."""
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test"),
description=I18nString(en_US="A test component"),
version="1.0.0",
icon="test-icon",
author="Test Author",
repository="https://github.com/test/repo",
name='test-component',
label=I18nString(en_US='Test'),
description=I18nString(en_US='A test component'),
version='1.0.0',
icon='test-icon',
author='Test Author',
repository='https://github.com/test/repo',
)
assert metadata.version == "1.0.0"
assert metadata.icon == "test-icon"
assert metadata.author == "Test Author"
assert metadata.version == '1.0.0'
assert metadata.icon == 'test-icon'
assert metadata.author == 'Test Author'
class TestComponentManifest:

View File

@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import isolation to break circular import chains.
"""
from __future__ import annotations
import sys
@@ -86,6 +87,7 @@ def get_database_module():
"""Get database module with import isolation."""
with isolated_database_import():
from langbot.pkg.persistence import database
return database
@@ -198,4 +200,4 @@ class TestManagerClassDecorator:
# Create instance to test method (with mock app)
mock_app = Mock()
instance = ManagerWithMethods(mock_app)
assert instance.custom_method() == 'test_value'
assert instance.custom_method() == 'test_value'

View File

@@ -4,6 +4,7 @@ Tests cover:
- execute_async() with mock database
- get_db_engine() with mock database manager
"""
from __future__ import annotations
import pytest
@@ -85,7 +86,7 @@ class TestExecuteAsync:
mock_db.get_engine = Mock(return_value=mock_engine)
mgr.db = mock_db
result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
result = await mgr.execute_async(sqlalchemy.text('SELECT 1'))
# Verify result is the same object returned by execute
assert result is mock_result
@@ -152,4 +153,4 @@ class TestSerializeModelEdgeCases:
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
# Result should be empty dict when all columns masked
assert result == {}
assert result == {}

View File

@@ -5,6 +5,7 @@ Tests cover:
- datetime conversion to isoformat
- masked_columns exclusion
"""
from __future__ import annotations
import datetime

View File

@@ -49,7 +49,7 @@ class TestPendingMessage:
"""PendingMessage should be created with correct fields."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -88,7 +88,7 @@ class TestSessionBuffer:
"""SessionBuffer should accept initial messages."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("hello")
chain2 = text_chain("world")
chain1 = text_chain('hello')
chain2 = text_chain('world')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
# Should contain both messages with separator
merged_str = str(merged.message_chain)
assert "hello" in merged_str
assert "world" in merged_str
assert 'hello' in merged_str
assert 'world' in merged_str
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("first")
chain2 = text_chain("second")
chain1 = text_chain('first')
chain2 = text_chain('second')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()

View File

@@ -15,6 +15,7 @@ from tests.factories import FakeApp
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -36,9 +37,11 @@ def mock_circular_import_chain():
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
@@ -70,9 +73,12 @@ def mock_event_ctx():
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
@@ -87,6 +93,7 @@ def get_chat_handler():
global _chat_handler_module
if _chat_handler_module is None:
from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module
@@ -96,12 +103,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class."""
@@ -188,9 +197,11 @@ class TestChatMessageHandlerReal:
class QuickRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='ok')
@@ -222,9 +233,11 @@ class TestChatMessageHandlerReal:
class SingleRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='response')
@@ -262,9 +275,11 @@ class TestChatHandlerStreaming:
class StreamRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True)
@@ -303,14 +318,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class FailingRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('API error')
yield
@@ -346,14 +366,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class ErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('Custom error')
yield
@@ -386,14 +411,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class HideErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise RuntimeError('hidden')
yield
@@ -433,4 +463,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result

View File

@@ -67,7 +67,11 @@ def make_pipeline_config(**overrides):
for key, value in overrides.items():
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
for sub_key, sub_value in value.items():
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
if (
sub_key in base_config[key]
and isinstance(base_config[key][sub_key], dict)
and isinstance(sub_value, dict)
):
base_config[key][sub_key].update(sub_value)
else:
base_config[key][sub_key] = sub_value
@@ -141,7 +145,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -163,7 +167,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
# Empty message chain
query = text_query("")
query = text_query('')
query.message_chain = platform_message.MessageChain([])
query.pipeline_config = pipeline_config
@@ -185,7 +189,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query(" ") # Only whitespace
query = text_query(' ') # Only whitespace
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -234,7 +238,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -266,7 +270,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -294,7 +298,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("http://example.com")
query = text_query('http://example.com')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -322,7 +326,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("normal message")
query = text_query('normal message')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -343,7 +347,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -368,12 +372,10 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Add a response message
query.resp_messages = [
provider_message.Message(role='assistant', content='Hello back!')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -398,11 +400,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='Response')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -422,7 +422,7 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-string content - use model_construct to bypass validation
# The actual content type could be a list of ContentElement objects
@@ -450,11 +450,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='')
]
query.resp_messages = [provider_message.Message(role='assistant', content='')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
@@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect:
await stage.initialize(pipeline_config)
query = text_query("normal message without prefix")
query = text_query('normal message without prefix')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')

View File

@@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -56,6 +57,7 @@ def mock_event_ctx():
@pytest.fixture
def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute(
text: str | None = 'ok',
error: str | None = None,
@@ -71,7 +73,9 @@ def mock_execute_factory():
ret.image_base64 = image_base64
ret.file_url = file_url
yield ret
return mock_execute
return _create_execute
@@ -86,6 +90,7 @@ def get_command_handler():
global _command_handler_module
if _command_handler_module is None:
from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module
@@ -95,12 +100,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal:
"""Tests for real CommandHandler class."""
@@ -127,6 +134,7 @@ class TestCommandHandlerReal:
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
executed_commands = []
async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text)
ret = Mock()
@@ -334,8 +342,7 @@ class TestCommandHandlerReal:
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:',
image_url='https://example.com/image.png'
text='Here is the image:', image_url='https://example.com/image.png'
)
handler = command.CommandHandler(fake_app)
@@ -393,4 +400,4 @@ class TestCommandHandlerHelper:
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result

View File

@@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="very long response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="short response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-Plain component (Image)
query.resp_message_chain = [
platform_message.MessageChain([
platform_message.Plain(text="short"),
platform_message.Image(url="https://example.com/img.png")
])
platform_message.MessageChain(
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
)
]
result = await stage.process(query, 'LongTextProcessStage')
@@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess:
stage = longtext.LongTextProcessStage(app)
stage.strategy_impl = AsyncMock()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
query.resp_message_chain = []
@@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess:
assert result.new_query is query
stage.strategy_impl.process.assert_not_called()
class TestForwardStrategy:
"""Tests for ForwardComponentStrategy."""
@@ -260,7 +256,7 @@ class TestForwardStrategy:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Create a mock adapter with bot_account_id
mock_adapter = Mock()
@@ -268,10 +264,8 @@ class TestForwardStrategy:
query.adapter = mock_adapter
# Long text exceeding threshold
long_text = "This is a very long response that exceeds the threshold"
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=long_text)])
]
long_text = 'This is a very long response that exceeds the threshold'
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -297,13 +291,13 @@ class TestForwardStrategy:
await strat.initialize()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config()
mock_adapter = Mock()
mock_adapter.bot_account_id = '12345'
query.adapter = mock_adapter
components = await strat.process("test message", query)
components = await strat.process('test message', query)
assert len(components) == 1
assert isinstance(components[0], platform_message.Forward)
@@ -326,14 +320,12 @@ class TestLongTextThreshold:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Text below threshold
short_text = "x" * (threshold - 1)
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=short_text)])
]
short_text = 'x' * (threshold - 1)
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
result = await stage.process(query, 'LongTextProcessStage')

View File

@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
# Create query with 3 messages (within limit)
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
# Create query with many messages exceeding limit
# 7 messages = 3 full rounds + 1 current user
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = []
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='hello'),
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='user1'),
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='old1'),
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
trun = trun_cls(app)
break
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_truncate_config(max_round=3)
query.messages = [
provider_message.Message(role='user', content='m1'),

View File

@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello world")
query = text_query('hello world')
result = await stage.process(query, 'PreProcessor')
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("test message")
query = text_query('test message')
result = await stage.process(query, 'PreProcessor')
@@ -194,13 +194,16 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app)
# Image query with base64
query = image_query(text="look at this", url=None)
query = image_query(text='look at this', url=None)
# Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([
platform_message.Plain(text="look at this"),
platform_message.Image(base64="data:image/png;base64,abc123"),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text='look at this'),
platform_message.Image(base64='data:image/png;base64,abc123'),
]
)
query.message_chain = chain
result = await stage.process(query, 'PreProcessor')
@@ -238,7 +241,7 @@ class TestPreProcessorImageSegment:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = image_query(text="describe this")
query = image_query(text='describe this')
result = await stage.process(query, 'PreProcessor')
@@ -276,7 +279,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
# Set pipeline config with primary model
query.pipeline_config = {
@@ -335,7 +338,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = {
'ai': {
@@ -384,7 +387,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello", sender_id=67890)
query = text_query('hello', sender_id=67890)
result = await stage.process(query, 'PreProcessor')
@@ -421,7 +424,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = group_text_query("hello", group_id=99999)
query = group_text_query('hello', group_id=99999)
result = await stage.process(query, 'PreProcessor')

View File

@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
'safety': {
'rate-limit': {
'window-length': 60, # 60 seconds window
'limitation': 10, # 10 requests per window
'limitation': 10, # 10 requests per window
'strategy': 'drop',
}
}
@@ -75,11 +75,9 @@ class TestFixedWindowAlgo:
# Make requests within limit
for i in range(10):
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
)
assert result is True, f"Request {i+1} should be allowed"
assert result is True, f'Request {i + 1} should be allowed'
@pytest.mark.asyncio
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -91,20 +89,12 @@ class TestFixedWindowAlgo:
# Exhaust the limit
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Next request should be denied
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
assert result is False, "Request exceeding limit should be denied"
assert result is False, 'Request exceeding limit should be denied'
@pytest.mark.asyncio
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -116,20 +106,14 @@ class TestFixedWindowAlgo:
# Exhaust limit for session 1
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session1'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
# Session 2 should still have its own limit
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session2'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
)
assert result is True, "Different session should have independent limit"
assert result is True, 'Different session should have independent limit'
@pytest.mark.asyncio
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
@@ -150,19 +134,11 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request allowed
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result1 is True
# Second request denied
result2 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result2 is False
@pytest.mark.asyncio
@@ -174,11 +150,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request creates container
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
expected_key = 'LauncherTypes.PERSON_12345'
@@ -230,7 +202,7 @@ class TestFixedWindowAlgo:
# New request should be allowed (new window)
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
assert result is True, "New window should allow new requests"
assert result is True, 'New window should allow new requests'
@pytest.mark.asyncio
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
@@ -256,29 +228,21 @@ class TestFixedWindowAlgo:
# First request allowed
start_time = time.time()
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
assert result1 is True
# Exhaust limit
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
# Third request should wait and then succeed
result3 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
elapsed = time.time() - start_time
assert result3 is True, "After wait, request should succeed"
assert result3 is True, 'After wait, request should succeed'
# Should have waited approximately until next window
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
# Note: This is a timing-sensitive test, so we use a generous tolerance
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
@pytest.mark.asyncio
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -289,11 +253,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# release_access is empty in current implementation
await algo.release_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Should not raise or change state
assert 'person_12345' not in algo.containers

View File

@@ -55,7 +55,7 @@ def make_session():
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
platform_message.MessageChain([platform_message.Plain(text="response")])
]
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
query.resp_message_chain = []
results = []
@@ -125,7 +123,7 @@ class TestResponseWrapperCommand:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -133,7 +131,7 @@ class TestResponseWrapperCommand:
command_resp = Mock()
command_resp.role = 'command'
command_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
)
query.resp_messages = [command_resp]
@@ -163,7 +161,7 @@ class TestResponseWrapperPlugin:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -171,7 +169,7 @@ class TestResponseWrapperPlugin:
plugin_resp = Mock()
plugin_resp.role = 'plugin'
plugin_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
)
query.resp_messages = [plugin_resp]
@@ -211,17 +209,17 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello back!"
assistant_resp.content = 'Hello back!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
)
query.resp_messages = [assistant_resp]
@@ -247,7 +245,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -292,7 +290,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -303,10 +301,10 @@ class TestResponseWrapperAssistant:
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Processing..."
assistant_resp.content = 'Processing...'
assistant_resp.tool_calls = [mock_tool_call]
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
)
query.resp_messages = [assistant_resp]
@@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello!"
assistant_resp.content = 'Hello!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
)
query.resp_messages = [assistant_resp]
@@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply:
app.sess_mgr.get_session = AsyncMock(return_value=session)
# Mock plugin connector with custom reply
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
@@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Default reply"
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
@@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply:
assert len(results[0].new_query.resp_message_chain) == 1
# Should be the custom chain
chain = results[0].new_query.resp_message_chain[0]
assert "Custom reply" in str(chain)
assert 'Custom reply' in str(chain)
class TestResponseWrapperVariables:
@@ -452,7 +450,7 @@ class TestResponseWrapperVariables:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
@@ -460,10 +458,10 @@ class TestResponseWrapperVariables:
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello"
assistant_resp.content = 'Hello'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
)
query.resp_messages = [assistant_resp]

View File

@@ -6,6 +6,7 @@ Tests cover:
- RAG methods (ingest, retrieve, schema)
- Disabled plugin early returns
"""
from __future__ import annotations
import pytest
@@ -86,16 +87,12 @@ class TestListPlugins:
return_value=[
{
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Command'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}],
'debug': False,
},
{
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Tool'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}],
'debug': False,
},
]
@@ -127,9 +124,7 @@ class TestListPlugins:
},
]
)
connector.ap.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(__iter__=lambda self: iter([]))
)
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([])))
result = await connector.list_plugins()
@@ -230,7 +225,8 @@ class TestCallParser:
)
connector.handler.parse_document.assert_called_once_with(
'author', 'parser',
'author',
'parser',
{'mime_type': 'text/plain', 'filename': 'test.txt'},
b'file content',
)
@@ -251,9 +247,7 @@ class TestRAGMethods:
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
connector.handler.rag_ingest_document.assert_called_once_with(
'author', 'engine', {'file': 'test.pdf'}
)
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'})
assert result['status'] == 'success'
@pytest.mark.asyncio
@@ -264,14 +258,16 @@ class TestRAGMethods:
connector.handler = AsyncMock()
connector.handler.retrieve_knowledge = AsyncMock(
return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
return_value={
'results': [
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
]
}
)
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
connector.handler.retrieve_knowledge.assert_called_once_with(
'author', 'engine', '', {'query': 'test'}
)
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'})
assert result == {
'results': [
{
@@ -290,9 +286,7 @@ class TestRAGMethods:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}})
result = await connector.get_rag_creation_schema('author/engine')
@@ -326,9 +320,7 @@ class TestRAGMethods:
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
connector.handler.rag_on_kb_create.assert_called_once_with(
'author', 'engine', 'kb-uuid', {'model': 'test'}
)
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'})
@pytest.mark.asyncio
async def test_rag_on_kb_delete(self):
@@ -354,9 +346,7 @@ class TestRAGMethods:
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
connector.handler.rag_delete_document.assert_called_once_with(
'author', 'engine', 'doc-uuid', 'kb-uuid'
)
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid')
assert result is True
@@ -446,9 +436,7 @@ class TestGetPluginInfo:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_plugin_info = AsyncMock(
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
)
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}})
result = await connector.get_plugin_info('author', 'plugin')
@@ -470,9 +458,7 @@ class TestSetPluginConfig:
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
connector.handler.set_plugin_config.assert_called_once_with(
'author', 'plugin', {'setting': 'value'}
)
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'})
class TestPingPluginRuntime:

View File

@@ -3,6 +3,7 @@
Tests cover:
- _parse_plugin_id() parsing and validation
"""
from __future__ import annotations
import pytest

View File

@@ -6,6 +6,7 @@ Tests cover:
- Handling missing requirements.txt
- Handling empty/malformed requirements.txt
"""
from __future__ import annotations
import zipfile
@@ -82,13 +83,13 @@ class TestExtractDepsMetadata:
"""Test that comments and empty lines are filtered."""
connector_instance = create_mock_connector()
requirements = '''# This is a comment
requirements = """# This is a comment
requests>=2.0
# Another comment
flask==1.0
numpy'''
numpy"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()
@@ -147,9 +148,9 @@ numpy'''
"""Test handling requirements.txt with only comments."""
connector_instance = create_mock_connector()
requirements = '''# Comment 1
requirements = """# Comment 1
# Comment 2
# Comment 3'''
# Comment 3"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()

View File

@@ -40,11 +40,13 @@ class TestHandlerQueryVariables:
"""Test set_query_var returns error when query not found."""
runtime_handler = make_handler(mock_app)
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -58,11 +60,13 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code == 0
assert mock_query.variables['test_var'] == 'test_value'
@@ -76,10 +80,12 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'existing_var',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'existing_var',
}
)
assert response.code == 0
assert response.data == {'value': 'existing_value'}
@@ -93,9 +99,11 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
'query_id': 'test-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value](
{
'query_id': 'test-query',
}
)
assert response.code == 0
assert response.data == {'vars': mock_query.variables}
@@ -108,7 +116,7 @@ class TestHandlerRagErrorResponse:
"""Test basic error response creation."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = Exception("test error")
error = Exception('test error')
response = _make_rag_error_response(error, 'TestError')
# ActionResponse is a pydantic model, check message field
@@ -120,13 +128,8 @@ class TestHandlerRagErrorResponse:
"""Test error response with extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = ValueError("invalid input")
response = _make_rag_error_response(
error,
'ValidationError',
field='name',
value='test'
)
error = ValueError('invalid input')
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test')
assert 'ValidationError' in response.message
assert 'field=name' in response.message
@@ -137,7 +140,7 @@ class TestHandlerRagErrorResponse:
"""Test error response includes exception type."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = RuntimeError("connection failed")
error = RuntimeError('connection failed')
response = _make_rag_error_response(error, 'ConnectionError')
assert 'RuntimeError' in response.message
@@ -148,7 +151,7 @@ class TestHandlerRagErrorResponse:
"""Test error response with no extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = KeyError("missing_key")
error = KeyError('missing_key')
response = _make_rag_error_response(error, 'LookupError')
# No context parts means no brackets

View File

@@ -47,12 +47,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -82,12 +84,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 3
@@ -161,9 +165,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'new')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new'))
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -203,9 +205,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'x')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x'))
assert response.code != 0
assert '1 > 0 bytes' in response.message
@@ -228,10 +228,12 @@ class TestGetPluginSettings:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -255,10 +257,12 @@ class TestGetPluginSettings:
)
app.persistence_mgr.execute_async.return_value = make_result(setting)
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -286,11 +290,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code == 0
assert response.data == {
@@ -303,11 +309,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code != 0
assert 'Storage with key test-key not found' in response.message
@@ -329,9 +337,11 @@ class TestHandlerQueryLookup:
"""Query-bound actions return error when query_id is not cached."""
runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'nonexistent-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'nonexistent-query',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -343,9 +353,11 @@ class TestHandlerQueryLookup:
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
app.query_pool.cached_queries['existing-query'] = query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'existing-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'existing-query',
}
)
assert response.code == 0
assert response.data == {'bot_uuid': 'test-bot-uuid'}

View File

@@ -4,6 +4,7 @@ Tests cover:
- _make_rag_error_response() helper function
- RuntimeConnectionHandler cleanup_plugin_data method
"""
from __future__ import annotations
import pytest
@@ -23,7 +24,7 @@ class TestMakeRagErrorResponse:
"""Test basic error response creation."""
handler = get_handler_module()
error = ValueError("test error message")
error = ValueError('test error message')
result = handler._make_rag_error_response(error, 'TestError')
# ActionResponse.error() returns code=1 (error status)
@@ -36,7 +37,7 @@ class TestMakeRagErrorResponse:
"""Test that error type is included in message."""
handler = get_handler_module()
error = RuntimeError("something went wrong")
error = RuntimeError('something went wrong')
result = handler._make_rag_error_response(error, 'VectorStoreError')
assert '[VectorStoreError/RuntimeError]' in result.message
@@ -45,7 +46,7 @@ class TestMakeRagErrorResponse:
"""Test that extra context fields are included."""
handler = get_handler_module()
error = Exception("embedding failed")
error = Exception('embedding failed')
result = handler._make_rag_error_response(
error,
'EmbeddingError',
@@ -71,7 +72,7 @@ class TestMakeRagErrorResponse:
"""Test multiple context fields are comma separated."""
handler = get_handler_module()
error = IOError("file not found")
error = IOError('file not found')
result = handler._make_rag_error_response(
error,
'FileServiceError',
@@ -119,9 +120,7 @@ class TestCleanupPluginData:
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
handler_instance.ap = mock_app
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
handler_instance, 'author', 'plugin-name'
)
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name')
# Should have at least 2 calls: one for settings, one for binary storage
assert mock_app.persistence_mgr.execute_async.call_count >= 2
assert mock_app.persistence_mgr.execute_async.call_count >= 2

View File

@@ -88,7 +88,10 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
return provider_message.Message(
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
)
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results."""
@@ -135,8 +138,10 @@ def mock_app_for_modelmgr():
# Fake persistence manager - returns empty results by default
app.persistence_mgr = SimpleNamespace()
async def default_execute(query):
return _make_mock_result([])
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
# Fake discover engine
@@ -165,9 +170,7 @@ def fake_requester_registry(mock_app_for_modelmgr):
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
app.discover.get_components_by_kind = Mock(
return_value=[fake_component, another_component]
)
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component])
model_mgr = ModelManager(app)
return model_mgr

View File

@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
with pytest.raises(DifyAPIError, match='不支持'):
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app
assert runner.ap == mock_app

View File

@@ -1062,9 +1062,7 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
mock_get_model_info.side_effect = (
lambda model: {'max_input_tokens': 131072}
if model == 'moonshot/moonshot-v1-128k'
else {}
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {}
)
assert requester._safe_context_length('moonshot-v1-128k') == 131072

View File

@@ -635,7 +635,9 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
model_mgr = fake_requester_registry
@@ -648,7 +650,9 @@ async def test_model_manager_load_llm_model_with_provider(fake_requester_registr
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider_from_row(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
model_mgr = fake_requester_registry
@@ -661,7 +665,9 @@ async def test_model_manager_load_llm_model_with_provider_from_row(fake_requeste
@pytest.mark.asyncio
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_embedding_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
model_mgr = fake_requester_registry

View File

@@ -43,6 +43,7 @@ class TestableRequester(requester.ProviderAPIRequester):
remove_think=False,
):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Testable response')],
@@ -289,7 +290,9 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
result = await provider.invoke_llm(query, runtime_llm_model, messages)
@@ -330,7 +333,9 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
chunks = []
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
@@ -576,7 +581,9 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
with pytest.raises(RequesterError):
await provider.invoke_llm(query, model, messages)

View File

@@ -5,6 +5,7 @@ Tests cover:
- Conversation creation with prompts
- Session concurrency semaphore
"""
from __future__ import annotations
import pytest
@@ -60,11 +61,7 @@ class TestSessionManagerGetSession:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -173,11 +170,7 @@ class TestSessionManagerGetConversation:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -201,17 +194,13 @@ class TestSessionManagerGetConversation:
return query
@pytest.mark.asyncio
async def test_creates_conversation_with_prompt(
self, mock_app_with_config, sample_query, sample_session
):
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session):
"""Test that get_conversation creates conversation with prompt."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
@@ -234,21 +223,15 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
# First call creates conversation
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
# Second call with same pipeline should return same conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
assert conv1 is conv2
assert len(sample_session.conversations) == 1
@@ -262,36 +245,26 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
# First call with pipeline1
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1')
# Second call with different pipeline should create new conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2')
assert conv1 is not conv2
assert len(sample_session.conversations) == 2
assert sample_session.using_conversation is conv2
@pytest.mark.asyncio
async def test_conversation_has_empty_messages(
self, mock_app_with_config, sample_query, sample_session
):
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session):
"""Test that created conversation has empty messages list."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
@@ -300,22 +273,17 @@ class TestSessionManagerGetConversation:
assert conversation.messages == []
@pytest.mark.asyncio
async def test_prompt_messages_from_config(
self, mock_app_with_config, sample_query, sample_session
):
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session):
"""Test that prompt messages are created from prompt_config."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'System message'},
{'role': 'user', 'content': 'User message'}
]
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
)
assert conversation.prompt.name == 'default'
assert len(conversation.prompt.messages) == 2
assert len(conversation.prompt.messages) == 2

View File

@@ -136,6 +136,7 @@ class TestToolManagerSchemaGeneration:
assert 'description' in func
assert 'parameters' in func
class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method."""

View File

@@ -3,6 +3,7 @@
Tests cover:
- _to_i18n_name() static method
"""
from __future__ import annotations
from importlib import import_module
@@ -60,4 +61,4 @@ class TestToI18nName:
kbmgr = get_kbmgr_module()
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
result = kbmgr.RAGManager._to_i18n_name(input_dict)
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}

View File

@@ -6,6 +6,7 @@ Tests cover:
- Knowledge engine enrichment
- KB loading and removal
"""
from __future__ import annotations
import pytest
@@ -101,13 +102,9 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin error')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin error'))
manager = rag_module.RAGManager(mock_app)
@@ -128,9 +125,7 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock()
@@ -206,9 +201,7 @@ class TestRuntimeKnowledgeBaseOnKBCreate:
mock_app = create_mock_app()
mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin failed')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin failed'))
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -245,9 +238,7 @@ class TestRuntimeKnowledgeBaseIngestDocument:
mock_app = create_mock_app()
mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.call_rag_ingest = AsyncMock(
return_value={'status': 'success'}
)
mock_app.plugin_connector.call_rag_ingest = AsyncMock(return_value={'status': 'success'})
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -304,14 +295,10 @@ class TestRAGManagerLoadKnowledgeBasesFromDB:
# KB that will cause initialize to fail
mock_kb = create_mock_kb_entity()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb])))
# Make initialize fail by having plugin_connector throw error
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Init failed')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Init failed'))
manager = rag_module.RAGManager(mock_app)
# Should not raise - errors are caught
@@ -411,9 +398,7 @@ class TestRuntimeKnowledgeBaseRetrieve:
mock_kb = create_mock_kb_entity()
mock_kb.retrieval_settings = {}
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(
return_value={'results': []}
)
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(return_value={'results': []})
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -682,9 +667,7 @@ class TestRAGManagerGetAllDetails:
"""Test returns empty list when no knowledge bases."""
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = rag_module.RAGManager(mock_app)
result = await manager.get_all_knowledge_base_details()
@@ -699,9 +682,7 @@ class TestRAGManagerGetAllDetails:
# Mock DB result
mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb_row]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb_row])))
mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
)
@@ -724,9 +705,7 @@ class TestRAGManagerGetDetails:
"""Test returns None when KB doesn't exist."""
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=None))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=None)))
manager = rag_module.RAGManager(mock_app)
result = await manager.get_knowledge_base_details('nonexistent')
@@ -740,9 +719,7 @@ class TestRAGManagerGetDetails:
mock_app = create_mock_app()
mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=mock_kb_row))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=mock_kb_row)))
mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
)
@@ -791,4 +768,4 @@ class TestRAGManagerLoadKnowledgeBase:
await manager.load_knowledge_base(kb_dict)
assert 'kb-uuid' in manager.knowledge_bases
assert 'kb-uuid' in manager.knowledge_bases

View File

@@ -121,10 +121,12 @@ class TestRAGRuntimeServiceVectorSearch:
"""Create mock app."""
mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.search = AsyncMock(return_value=[
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
])
mock_app.vector_db_mgr.search = AsyncMock(
return_value=[
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
]
)
return mock_app
def _make_rag_import_mocks(self):
@@ -301,10 +303,7 @@ class TestRAGRuntimeServiceVectorList:
mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.list_by_filter = AsyncMock(
return_value=(
[{'id': 'id1', 'metadata': {'file_id': 'abc'}}],
10
)
return_value=([{'id': 'id1', 'metadata': {'file_id': 'abc'}}], 10)
)
return mock_app

View File

@@ -21,8 +21,8 @@ from langbot.pkg.storage.providers.localstorage import LocalStorageProvider
@pytest.fixture
def storage_provider(tmp_path):
"""Create a LocalStorageProvider with a temporary storage path."""
storage_path = str(tmp_path / "storage")
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
storage_path = str(tmp_path / 'storage')
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
mock_app = Mock()
provider = LocalStorageProvider(mock_app)
yield provider, storage_path
@@ -35,15 +35,15 @@ class TestPathTraversalPrevention:
async def test_absolute_path_save_rejected(self, storage_provider, tmp_path):
"""Saving with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "pwned.txt")
target_file = str(tmp_path / 'pwned.txt')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError)):
await provider.save(target_file, b"malicious content")
await provider.save(target_file, b'malicious content')
# The file must NOT exist outside the storage directory
assert not os.path.exists(target_file), (
f"Path traversal succeeded: file was written outside storage to {target_file}"
f'Path traversal succeeded: file was written outside storage to {target_file}'
)
@pytest.mark.asyncio
@@ -52,32 +52,28 @@ class TestPathTraversalPrevention:
provider, storage_path = storage_provider
# Create a file outside the storage directory
target_file = str(tmp_path / "secret.txt")
with open(target_file, "wb") as f:
f.write(b"secret data")
target_file = str(tmp_path / 'secret.txt')
with open(target_file, 'wb') as f:
f.write(b'secret data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(target_file)
assert data != b"secret data", (
"Path traversal succeeded: read file outside storage"
)
assert data != b'secret data', 'Path traversal succeeded: read file outside storage'
@pytest.mark.asyncio
async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path):
"""Exists check with an absolute path key must be blocked or return False."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "check_me.txt")
with open(target_file, "wb") as f:
f.write(b"data")
target_file = str(tmp_path / 'check_me.txt')
with open(target_file, 'wb') as f:
f.write(b'data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
try:
result = await provider.exists(target_file)
assert result is False, (
"Path traversal succeeded: exists() returned True for file outside storage"
)
assert result is False, 'Path traversal succeeded: exists() returned True for file outside storage'
except (ValueError, PermissionError):
pass # Expected
@@ -86,28 +82,26 @@ class TestPathTraversalPrevention:
"""Deleting with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "do_not_delete.txt")
with open(target_file, "wb") as f:
f.write(b"important data")
target_file = str(tmp_path / 'do_not_delete.txt')
with open(target_file, 'wb') as f:
f.write(b'important data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.delete(target_file)
assert os.path.exists(target_file), (
"Path traversal succeeded: file outside storage was deleted"
)
assert os.path.exists(target_file), 'Path traversal succeeded: file outside storage was deleted'
@pytest.mark.asyncio
async def test_absolute_path_size_rejected(self, storage_provider, tmp_path):
"""Size check with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "measure_me.txt")
with open(target_file, "wb") as f:
f.write(b"some data")
target_file = str(tmp_path / 'measure_me.txt')
with open(target_file, 'wb') as f:
f.write(b'some data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.size(target_file)
@@ -116,41 +110,39 @@ class TestPathTraversalPrevention:
"""Relative path traversal with '..' must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "above_storage.txt")
with open(target_file, "wb") as f:
f.write(b"above storage secret")
target_file = str(tmp_path / 'above_storage.txt')
with open(target_file, 'wb') as f:
f.write(b'above storage secret')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
relative_key = os.path.join("..", "above_storage.txt")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
relative_key = os.path.join('..', 'above_storage.txt')
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(relative_key)
assert data != b"above storage secret"
assert data != b'above storage secret'
@pytest.mark.asyncio
async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path):
"""delete_dir_recursive with traversal path must be blocked."""
provider, storage_path = storage_provider
outside_dir = tmp_path / "outside_dir"
outside_dir = tmp_path / 'outside_dir'
outside_dir.mkdir()
(outside_dir / "file.txt").write_text("important")
(outside_dir / 'file.txt').write_text('important')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError)):
await provider.delete_dir_recursive(str(outside_dir))
assert outside_dir.exists(), (
"Path traversal succeeded: directory outside storage was deleted"
)
assert outside_dir.exists(), 'Path traversal succeeded: directory outside storage was deleted'
@pytest.mark.asyncio
async def test_legitimate_key_works(self, storage_provider):
"""Normal keys without traversal must still work."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = "test_image_abc123.png"
content = b"PNG image data"
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
key = 'test_image_abc123.png'
content = b'PNG image data'
await provider.save(key, content)
assert await provider.exists(key) is True
@@ -166,9 +158,9 @@ class TestPathTraversalPrevention:
"""Keys with legitimate subdirectories must still work."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = "bot_log_images/img_001.png"
content = b"PNG image data"
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
key = 'bot_log_images/img_001.png'
content = b'PNG image data'
await provider.save(key, content)
assert await provider.exists(key) is True
@@ -181,33 +173,33 @@ class TestPathTraversalPrevention:
"""delete_dir_recursive should handle non-existing directories gracefully."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
# Try to delete a non-existing directory - should not raise
await provider.delete_dir_recursive("nonexistent_dir")
await provider.delete_dir_recursive('nonexistent_dir')
@pytest.mark.asyncio
async def test_delete_dir_recursive_with_files(self, storage_provider):
"""delete_dir_recursive should delete directory with files inside."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
# Create a directory with files
key1 = "test_dir/file1.txt"
key2 = "test_dir/file2.txt"
await provider.save(key1, b"content1")
await provider.save(key2, b"content2")
key1 = 'test_dir/file1.txt'
key2 = 'test_dir/file2.txt'
await provider.save(key1, b'content1')
await provider.save(key2, b'content2')
# Verify files exist
assert await provider.exists(key1)
assert await provider.exists(key2)
# Delete directory recursively
await provider.delete_dir_recursive("test_dir")
await provider.delete_dir_recursive('test_dir')
# Verify files no longer exist
assert not await provider.exists(key1)
assert not await provider.exists(key2)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -8,6 +8,7 @@ Tests cover:
Uses moto library to mock AWS S3 service.
"""
from __future__ import annotations
import pytest
@@ -44,8 +45,10 @@ def mock_app_with_s3_config():
def s3_mock():
"""Set up moto S3 mock context."""
from moto import mock_aws
with mock_aws():
import boto3
# Create bucket for tests that need pre-existing bucket
s3 = boto3.client('s3', region_name='us-east-1')
yield s3
@@ -325,4 +328,4 @@ class TestS3StorageProviderErrorHandling:
await provider.initialize()
with pytest.raises(Exception):
await provider.size('nonexistent.txt')
await provider.size('nonexistent.txt')

View File

@@ -31,7 +31,7 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
mock_app.logger.info.assert_called()
@@ -41,12 +41,12 @@ class TestStorageMgr:
"""Should use local storage when explicitly configured."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {"storage": {"use": "local"}}
mock_app.instance_config.data = {'storage': {'use': 'local'}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -55,14 +55,12 @@ class TestStorageMgr:
"""Should use S3 storage when configured."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
"storage": {"use": "s3", "s3": {"endpoint_url": "https://s3.amazonaws.com"}}
}
mock_app.instance_config.data = {'storage': {'use': 's3', 's3': {'endpoint_url': 'https://s3.amazonaws.com'}}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(S3StorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(S3StorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, S3StorageProvider)
@@ -71,12 +69,12 @@ class TestStorageMgr:
"""Should default to local storage for invalid storage type."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {"storage": {"use": "invalid_type"}}
mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -90,9 +88,7 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app)
with patch.object(
LocalStorageProvider, "initialize", new_callable=AsyncMock
) as mock_init:
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock) as mock_init:
await storage_mgr.initialize()
mock_init.assert_called_once()
@@ -105,8 +101,8 @@ class TestStorageProviderBase:
mock_app = Mock()
# Use LocalStorageProvider as concrete implementation
with patch("os.path.exists", return_value=True):
with patch("os.makedirs"):
with patch('os.path.exists', return_value=True):
with patch('os.makedirs'):
provider = LocalStorageProvider(mock_app)
assert provider.ap == mock_app
@@ -115,12 +111,12 @@ class TestStorageProviderBase:
"""Provider base initialize should be callable and do nothing."""
mock_app = Mock()
with patch("os.path.exists", return_value=True):
with patch("os.makedirs"):
with patch('os.path.exists', return_value=True):
with patch('os.makedirs'):
provider = LocalStorageProvider(mock_app)
# Initialize should not raise
await provider.initialize()
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -8,6 +8,7 @@ Tests cover:
- HTTP request success/failure scenarios
- Source code bug: send_tasks should be instance variable
"""
from __future__ import annotations
import pytest
@@ -38,6 +39,7 @@ class TestTelemetryManagerInit:
manager = telemetry.TelemetryManager(mock_app)
assert manager.telemetry_config == {}
class TestTelemetryManagerInitialize:
"""Tests for initialize() method."""
@@ -218,7 +220,7 @@ class TestPayloadSanitization:
# All null string fields should be empty strings
for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']:
assert result[field] == '', f"Field {field} should be empty string, got {result[field]}"
assert result[field] == '', f'Field {field} should be empty string, got {result[field]}'
@pytest.mark.asyncio
async def test_sanitize_string_fields_preserve_values(self):
@@ -418,9 +420,7 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock(
status_code=200,
text='{"code": 0, "msg": "success"}',
json=Mock(return_value={'code': 0, 'msg': 'success'})
status_code=200, text='{"code": 0, "msg": "success"}', json=Mock(return_value={'code': 0, 'msg': 'success'})
)
mock_client = Mock()
@@ -448,9 +448,7 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock(
status_code=500,
text='Internal Server Error',
json=Mock(return_value={'code': 500, 'msg': 'error'})
status_code=500, text='Internal Server Error', json=Mock(return_value={'code': 500, 'msg': 'error'})
)
mock_client = Mock()
@@ -478,7 +476,7 @@ class TestHTTPScenarios:
mock_response = Mock(
status_code=200,
text='{"code": 400, "msg": "Bad Request"}',
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'})
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}),
)
mock_client = Mock()
@@ -493,7 +491,7 @@ class TestHTTPScenarios:
assert mock_app.logger.warning.call_count >= 1
# Check that one of the calls contains application error info
all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list]
assert any('400' in w for w in all_warnings), f"No warning contained error code 400: {all_warnings}"
assert any('400' in w for w in all_warnings), f'No warning contained error code 400: {all_warnings}'
@pytest.mark.asyncio
async def test_send_timeout_logs_warning(self):

View File

@@ -9,6 +9,7 @@ Tests cover:
Note: Do NOT use 'from __future__ import annotations' because
funcschema.py expects actual type objects, not string annotations.
"""
import pytest
from importlib import import_module

View File

@@ -20,55 +20,53 @@ class TestGetQQImageDownloadableUrl:
def test_basic_url(self):
"""Parse basic image URL."""
url = "http://example.com/image.jpg"
url = 'http://example.com/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/image.jpg"
assert result_url == 'http://example.com/image.jpg'
assert query == {}
def test_url_with_query_params(self):
"""Parse URL with query parameters."""
url = "http://example.com/image.jpg?param1=value1&param2=value2"
url = 'http://example.com/image.jpg?param1=value1&param2=value2'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/image.jpg"
assert query == {"param1": ["value1"], "param2": ["value2"]}
assert result_url == 'http://example.com/image.jpg'
assert query == {'param1': ['value1'], 'param2': ['value2']}
def test_url_with_port(self):
"""Parse URL with port number."""
url = "http://example.com:8080/image.jpg"
url = 'http://example.com:8080/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com:8080/image.jpg"
assert result_url == 'http://example.com:8080/image.jpg'
def test_url_with_path(self):
"""Parse URL with complex path."""
url = "http://example.com/path/to/image.jpg"
url = 'http://example.com/path/to/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/path/to/image.jpg"
assert result_url == 'http://example.com/path/to/image.jpg'
def test_url_with_fragment(self):
"""Parse URL with fragment (fragment is not part of query)."""
url = "http://example.com/image.jpg#fragment"
url = 'http://example.com/image.jpg#fragment'
result_url, query = get_qq_image_downloadable_url(url)
# Fragment is not included in query string parsing
assert "http://example.com/image.jpg" in result_url
assert 'http://example.com/image.jpg' in result_url
def test_https_url(self):
"""Parse HTTPS URL and preserve its scheme."""
url = "https://example.com/image.jpg"
url = 'https://example.com/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "https://example.com/image.jpg"
assert result_url == 'https://example.com/image.jpg'
assert query == {}
def test_preserves_qq_https_scheme_and_query(self):
"""QQ image URLs keep HTTPS and query parameters."""
result_url, query = get_qq_image_downloadable_url(
'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1'
)
result_url, query = get_qq_image_downloadable_url('https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1')
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
assert query == {'term': ['2'], 'is_origin': ['1']}
@@ -88,50 +86,50 @@ class TestExtractB64AndFormat:
async def test_jpeg_data_uri(self):
"""Extract base64 and format from JPEG data URI."""
# Create a simple base64 string
original_data = b"test image data"
original_data = b'test image data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/jpeg;base64,{b64_data}"
data_uri = f'data:image/jpeg;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "jpeg"
assert result_format == 'jpeg'
@pytest.mark.asyncio
async def test_png_data_uri(self):
"""Extract base64 and format from PNG data URI."""
original_data = b"test png data"
original_data = b'test png data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/png;base64,{b64_data}"
data_uri = f'data:image/png;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "png"
assert result_format == 'png'
@pytest.mark.asyncio
async def test_gif_data_uri(self):
"""Extract base64 and format from GIF data URI."""
original_data = b"test gif data"
original_data = b'test gif data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/gif;base64,{b64_data}"
data_uri = f'data:image/gif;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "gif"
assert result_format == 'gif'
@pytest.mark.asyncio
async def test_webp_data_uri(self):
"""Extract base64 and format from WebP data URI."""
original_data = b"test webp data"
original_data = b'test webp data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/webp;base64,{b64_data}"
data_uri = f'data:image/webp;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "webp"
assert result_format == 'webp'
@pytest.mark.asyncio
async def test_complex_base64(self):
@@ -139,7 +137,7 @@ class TestExtractB64AndFormat:
# Base64 can include + and / characters
original_data = bytes(range(256)) # All byte values
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/png;base64,{b64_data}"
data_uri = f'data:image/png;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
@@ -150,9 +148,9 @@ class TestExtractB64AndFormat:
@pytest.mark.asyncio
async def test_empty_base64(self):
"""Handle empty base64 string."""
data_uri = "data:image/png;base64,"
data_uri = 'data:image/png;base64,'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == ""
assert result_format == "png"
assert result_b64 == ''
assert result_format == 'png'

View File

@@ -23,52 +23,52 @@ class TestImportDir:
def test_calls_importlib_for_each_python_file(self, tmp_path):
"""Should call importlib.import_module for each .py file."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "__init__.py").write_text("")
(module_dir / "module_a.py").write_text("VALUE_A = 'a'\n")
(module_dir / "module_b.py").write_text("VALUE_B = 'b'\n")
(module_dir / "readme.txt").write_text("not a module")
(module_dir / '__init__.py').write_text('')
(module_dir / 'module_a.py').write_text("VALUE_A = 'a'\n")
(module_dir / 'module_b.py').write_text("VALUE_B = 'b'\n")
(module_dir / 'readme.txt').write_text('not a module')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# Should call import_module for each .py file (excluding __init__.py)
assert mock_import.call_count == 2
def test_skips_init_py(self, tmp_path):
"""Should skip __init__.py when importing."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "__init__.py").write_text("")
(module_dir / "regular.py").write_text("VALUE = 1\n")
(module_dir / '__init__.py').write_text('')
(module_dir / 'regular.py').write_text('VALUE = 1\n')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# __init__.py should be skipped
mock_import.assert_called_once()
# The call should not include __init__
call_args = mock_import.call_args[0][0]
assert "__init__" not in call_args
assert '__init__' not in call_args
def test_ignores_non_py_files(self, tmp_path):
"""Should ignore non-.py files."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "module.py").write_text("VALUE = 1\n")
(module_dir / "readme.txt").write_text("text")
(module_dir / "data.json").write_text("{}")
(module_dir / 'module.py').write_text('VALUE = 1\n')
(module_dir / 'readme.txt').write_text('text')
(module_dir / 'data.json').write_text('{}')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# Only .py files should be imported
assert mock_import.call_count == 1
@@ -79,14 +79,14 @@ class TestImportModulesInPkg:
def test_imports_modules_from_package(self, tmp_path):
"""Should import all modules from a package object."""
mock_pkg = MagicMock()
mock_pkg.__file__ = str(tmp_path / "__init__.py")
mock_pkg.__file__ = str(tmp_path / '__init__.py')
(tmp_path / "__init__.py").write_text("")
(tmp_path / "mod1.py").write_text("MOD1 = 1\n")
(tmp_path / '__init__.py').write_text('')
(tmp_path / 'mod1.py').write_text('MOD1 = 1\n')
from langbot.pkg.utils import importutil
with patch.object(importutil, "import_dir") as mock_import_dir:
with patch.object(importutil, 'import_dir') as mock_import_dir:
importutil.import_modules_in_pkg(mock_pkg)
mock_import_dir.assert_called_once()
call_path = mock_import_dir.call_args[0][0]
@@ -101,11 +101,11 @@ class TestImportModulesInPkgs:
from langbot.pkg.utils import importutil
mock_pkg1 = MagicMock()
mock_pkg1.__file__ = "/path/to/pkg1/__init__.py"
mock_pkg1.__file__ = '/path/to/pkg1/__init__.py'
mock_pkg2 = MagicMock()
mock_pkg2.__file__ = "/path/to/pkg2/__init__.py"
mock_pkg2.__file__ = '/path/to/pkg2/__init__.py'
with patch.object(importutil, "import_modules_in_pkg") as mock_import:
with patch.object(importutil, 'import_modules_in_pkg') as mock_import:
importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2])
assert mock_import.call_count == 2
@@ -116,18 +116,18 @@ class TestImportDotStyleDir:
def test_converts_dot_notation_to_path(self, tmp_path):
"""Should convert dot notation to path and import."""
# Create structure matching the dot notation
(tmp_path / "my").mkdir()
(tmp_path / "my" / "pkg").mkdir()
(tmp_path / "my" / "pkg" / "test").mkdir()
(tmp_path / 'my').mkdir()
(tmp_path / 'my' / 'pkg').mkdir()
(tmp_path / 'my' / 'pkg' / 'test').mkdir()
from langbot.pkg.utils import importutil
with patch.object(importutil, "import_dir") as mock_import_dir:
importutil.import_dot_style_dir("my.pkg.test")
with patch.object(importutil, 'import_dir') as mock_import_dir:
importutil.import_dot_style_dir('my.pkg.test')
# The path should be converted using os.path.join
call_path = mock_import_dir.call_args[0][0]
# Should contain the path components joined
assert "my" in call_path
assert 'my' in call_path
class TestReadResourceFile:
@@ -137,16 +137,16 @@ class TestReadResourceFile:
"""Should read content from a resource file."""
from langbot.pkg.utils import importutil
content = importutil.read_resource_file("templates/config.yaml")
assert "admins:" in content
assert "edition: community" in content
content = importutil.read_resource_file('templates/config.yaml')
assert 'admins:' in content
assert 'edition: community' in content
def test_raises_for_nonexistent_file(self):
"""Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file("nonexistent/path/file.txt")
importutil.read_resource_file('nonexistent/path/file.txt')
class TestReadResourceFileBytes:
@@ -156,16 +156,16 @@ class TestReadResourceFileBytes:
"""Should read content as bytes from a resource file."""
from langbot.pkg.utils import importutil
content = importutil.read_resource_file_bytes("templates/config.yaml")
assert b"admins:" in content
assert b"edition: community" in content
content = importutil.read_resource_file_bytes('templates/config.yaml')
assert b'admins:' in content
assert b'edition: community' in content
def test_raises_for_nonexistent_file_bytes(self):
"""Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file_bytes("nonexistent/path/file.txt")
importutil.read_resource_file_bytes('nonexistent/path/file.txt')
class TestListResourceFiles:
@@ -175,9 +175,9 @@ class TestListResourceFiles:
"""Should list files in a resource directory."""
from langbot.pkg.utils import importutil
files = importutil.list_resource_files("templates")
assert "config.yaml" in files
assert "default-pipeline-config.json" in files
files = importutil.list_resource_files('templates')
assert 'config.yaml' in files
assert 'default-pipeline-config.json' in files
assert all(isinstance(file, str) for file in files)
def test_raises_for_nonexistent_directory(self):
@@ -185,8 +185,8 @@ class TestListResourceFiles:
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.list_resource_files("nonexistent_directory_xyz")
importutil.list_resource_files('nonexistent_directory_xyz')
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -5,6 +5,7 @@ Tests cover:
- Docker environment detection
- WebSocket plugin runtime mode
"""
from __future__ import annotations
import os
@@ -86,4 +87,4 @@ class TestGetPlatform:
assert platform_module.use_websocket_to_connect_plugin_runtime() is True
# Restore
platform_module.standalone_runtime = original
platform_module.standalone_runtime = original

View File

@@ -60,10 +60,12 @@ class TestProxyManager:
async def test_initialize_config_overrides_env(self):
"""Config proxy overrides environment variables."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://config-proxy:8080',
'https': 'https://config-proxy:8443',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://config-proxy:8080',
'https': 'https://config-proxy:8443',
}
)
with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}):
pm = ProxyManager(mock_app)
@@ -74,10 +76,12 @@ class TestProxyManager:
async def test_initialize_sets_env_variables(self):
"""initialize sets proxy to environment variables."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://test-proxy:8080',
'https': 'https://test-proxy:8443',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://test-proxy:8080',
'https': 'https://test-proxy:8443',
}
)
pm = ProxyManager(mock_app)
await pm.initialize()
@@ -143,9 +147,11 @@ class TestProxyManager:
async def test_initialize_http_only_config(self):
"""initialize handles http-only config."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://http-only:8080',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://http-only:8080',
}
)
# Clear any existing proxy env vars
env_backup = {}

View File

@@ -29,63 +29,63 @@ class TestGetRunnerCategory:
def test_empty_url_returns_unknown(self):
"""Empty or None URL should return UNKNOWN."""
assert get_runner_category("test", "") == RunnerCategory.UNKNOWN
assert get_runner_category("test", None) == RunnerCategory.UNKNOWN
assert get_runner_category('test', '') == RunnerCategory.UNKNOWN
assert get_runner_category('test', None) == RunnerCategory.UNKNOWN
def test_localhost_returns_local(self):
"""localhost URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://localhost:3000") == RunnerCategory.LOCAL
assert get_runner_category("test", "https://localhost") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://localhost:3000') == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://localhost') == RunnerCategory.LOCAL
def test_127_0_0_1_returns_local(self):
"""127.0.0.1 URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://127.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "https://127.0.0.1") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://127.0.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://127.0.0.1') == RunnerCategory.LOCAL
def test_0_0_0_0_returns_local(self):
"""0.0.0.0 URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://0.0.0.0:8080") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://0.0.0.0:8080') == RunnerCategory.LOCAL
def test_private_ip_192_168_returns_local(self):
"""192.168.x.x private IP should be categorized as LOCAL."""
assert get_runner_category("test", "http://192.168.1.1:3000") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://192.168.0.100") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://192.168.1.1:3000') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://192.168.0.100') == RunnerCategory.LOCAL
def test_private_ip_10_returns_local(self):
"""10.x.x.x private IP should be categorized as LOCAL."""
assert get_runner_category("test", "http://10.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://10.255.255.255") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://10.0.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://10.255.255.255') == RunnerCategory.LOCAL
def test_private_ip_172_16_31_returns_local(self):
"""172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL."""
assert get_runner_category("test", "http://172.16.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://172.20.0.1") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://172.31.255.255") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.16.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.20.0.1') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.31.255.255') == RunnerCategory.LOCAL
def test_n8n_cloud_returns_cloud(self):
"""n8n.cloud domain should be categorized as CLOUD."""
assert get_runner_category("test", "https://myinstance.n8n.cloud") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://test.n8n.io") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://myinstance.n8n.cloud') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.n8n.io') == RunnerCategory.CLOUD
def test_dify_cloud_returns_cloud(self):
"""Dify cloud domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://api.dify.ai/v1") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://cloud.dify.ai") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.dify.ai/v1') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://cloud.dify.ai') == RunnerCategory.CLOUD
def test_coze_cloud_returns_cloud(self):
"""Coze domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://api.coze.com") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://api.coze.cn") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.coze.com') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.coze.cn') == RunnerCategory.CLOUD
def test_langflow_cloud_returns_cloud(self):
"""Langflow domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://cloud.langflow.ai") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://test.langflow.org") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://cloud.langflow.ai') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.langflow.org') == RunnerCategory.CLOUD
def test_other_url_returns_cloud(self):
"""Other URLs should default to CLOUD category."""
assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://example.com') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://myserver.example.org') == RunnerCategory.CLOUD
@pytest.mark.parametrize(
'runner_url',
@@ -101,7 +101,7 @@ class TestGetRunnerCategory:
)
def test_invalid_urls_return_unknown(self, runner_url):
"""Invalid or incomplete URLs should return UNKNOWN."""
assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
assert get_runner_category('test', runner_url) == RunnerCategory.UNKNOWN
def test_urlparse_exception_returns_unknown(self):
"""Exception during URL parsing should return UNKNOWN."""
@@ -109,15 +109,15 @@ class TestGetRunnerCategory:
from langbot.pkg.utils import runner
def mock_urlparse(url):
raise Exception("URL parsing failed")
raise Exception('URL parsing failed')
with patch("langbot.pkg.utils.runner.urlparse", side_effect=mock_urlparse):
result = runner.get_runner_category("test", "http://example.com")
with patch('langbot.pkg.utils.runner.urlparse', side_effect=mock_urlparse):
result = runner.get_runner_category('test', 'http://example.com')
assert result == RunnerCategory.UNKNOWN
def test_url_without_scheme_returns_unknown(self):
"""URL without scheme should return UNKNOWN."""
assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN
assert get_runner_category('test', 'example.com') == RunnerCategory.UNKNOWN
@pytest.mark.parametrize(
'runner_url',
@@ -146,20 +146,21 @@ class TestGetRunnerCategory:
"""Domain names that only look like private IP prefixes should not be LOCAL."""
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
class TestIsCloudRunner:
"""Test is_cloud_runner helper function."""
def test_cloud_runner_returns_true(self):
"""Cloud URL should return True."""
assert is_cloud_runner("test", "https://api.dify.ai") is True
assert is_cloud_runner('test', 'https://api.dify.ai') is True
def test_local_runner_returns_false(self):
"""Local URL should return False."""
assert is_cloud_runner("test", "http://localhost:3000") is False
assert is_cloud_runner('test', 'http://localhost:3000') is False
def test_unknown_returns_false(self):
"""Unknown category should return False."""
assert is_cloud_runner("test", None) is False
assert is_cloud_runner('test', None) is False
class TestIsLocalRunner:
@@ -167,15 +168,15 @@ class TestIsLocalRunner:
def test_local_runner_returns_true(self):
"""Local URL should return True."""
assert is_local_runner("test", "http://localhost:3000") is True
assert is_local_runner('test', 'http://localhost:3000') is True
def test_cloud_runner_returns_false(self):
"""Cloud URL should return False."""
assert is_local_runner("test", "https://api.dify.ai") is False
assert is_local_runner('test', 'https://api.dify.ai') is False
def test_unknown_returns_false(self):
"""Unknown category should return False."""
assert is_local_runner("test", None) is False
assert is_local_runner('test', None) is False
class TestGetRunnerInfo:
@@ -183,17 +184,17 @@ class TestGetRunnerInfo:
def test_returns_dict_with_expected_keys(self):
"""Should return dict with name, url, and category keys."""
info = get_runner_info("my-runner", "http://localhost:3000")
assert "name" in info
assert "url" in info
assert "category" in info
info = get_runner_info('my-runner', 'http://localhost:3000')
assert 'name' in info
assert 'url' in info
assert 'category' in info
def test_includes_correct_values(self):
"""Should include correct values in dict."""
info = get_runner_info("my-runner", "http://localhost:3000")
assert info["name"] == "my-runner"
assert info["url"] == "http://localhost:3000"
assert info["category"] == RunnerCategory.LOCAL
info = get_runner_info('my-runner', 'http://localhost:3000')
assert info['name'] == 'my-runner'
assert info['url'] == 'http://localhost:3000'
assert info['category'] == RunnerCategory.LOCAL
class TestExtractRunnerUrl:
@@ -203,74 +204,58 @@ class TestExtractRunnerUrl:
"""Should extract base-url from dify-service-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
url = extract_runner_url("dify-service-api", runner, pipeline_config)
assert url == "https://api.dify.ai"
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
url = extract_runner_url('dify-service-api', runner, pipeline_config)
assert url == 'https://api.dify.ai'
def test_n8n_service_api_extracts_url(self):
"""Should extract webhook-url from n8n-service-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"n8n-service-api": {"webhook-url": "https://my.n8n.cloud/webhook"}
}
}
url = extract_runner_url("n8n-service-api", runner, pipeline_config)
assert url == "https://my.n8n.cloud/webhook"
pipeline_config = {'ai': {'n8n-service-api': {'webhook-url': 'https://my.n8n.cloud/webhook'}}}
url = extract_runner_url('n8n-service-api', runner, pipeline_config)
assert url == 'https://my.n8n.cloud/webhook'
def test_coze_api_extracts_url(self):
"""Should extract api-base from coze-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"coze-api": {"api-base": "https://api.coze.com"}
}
}
url = extract_runner_url("coze-api", runner, pipeline_config)
assert url == "https://api.coze.com"
pipeline_config = {'ai': {'coze-api': {'api-base': 'https://api.coze.com'}}}
url = extract_runner_url('coze-api', runner, pipeline_config)
assert url == 'https://api.coze.com'
def test_langflow_api_extracts_url(self):
"""Should extract base-url from langflow-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"langflow-api": {"base-url": "https://cloud.langflow.ai"}
}
}
url = extract_runner_url("langflow-api", runner, pipeline_config)
assert url == "https://cloud.langflow.ai"
pipeline_config = {'ai': {'langflow-api': {'base-url': 'https://cloud.langflow.ai'}}}
url = extract_runner_url('langflow-api', runner, pipeline_config)
assert url == 'https://cloud.langflow.ai'
def test_unknown_runner_returns_none(self):
"""Unknown runner name should return None."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {}
url = extract_runner_url("unknown-runner", runner, pipeline_config)
url = extract_runner_url('unknown-runner', runner, pipeline_config)
assert url is None
def test_none_runner_returns_none(self):
"""None runner should return None."""
url = extract_runner_url("test", None, {})
url = extract_runner_url('test', None, {})
assert url is None
def test_runner_without_pipeline_config_returns_none(self):
"""Runner without pipeline_config attribute should return None."""
runner = Mock(spec=[]) # Empty spec means no attributes
url = extract_runner_url("test", runner, {})
url = extract_runner_url('test', runner, {})
assert url is None
def test_none_pipeline_config_returns_none(self):
"""None pipeline_config should return None."""
runner = Mock()
runner.pipeline_config = {}
url = extract_runner_url("dify-service-api", runner, None)
url = extract_runner_url('dify-service-api', runner, None)
assert url is None
def test_missing_ai_config_returns_none(self):
@@ -278,7 +263,7 @@ class TestExtractRunnerUrl:
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {}
url = extract_runner_url("dify-service-api", runner, pipeline_config)
url = extract_runner_url('dify-service-api', runner, pipeline_config)
assert url is None
@@ -289,19 +274,15 @@ class TestGetRunnerCategoryFromRunner:
"""Should extract URL and return correct category."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
category = get_runner_category_from_runner("dify-service-api", runner, pipeline_config)
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
category = get_runner_category_from_runner('dify-service-api', runner, pipeline_config)
assert category == RunnerCategory.CLOUD
def test_returns_unknown_for_missing_url(self):
"""Should return UNKNOWN when URL cannot be extracted."""
runner = Mock()
runner.pipeline_config = {}
category = get_runner_category_from_runner("unknown", runner, {})
category = get_runner_category_from_runner('unknown', runner, {})
assert category == RunnerCategory.UNKNOWN
@@ -310,9 +291,9 @@ class TestConstants:
def test_runner_category_constants(self):
"""RunnerCategory should have LOCAL, CLOUD, UNKNOWN."""
assert RunnerCategory.LOCAL == "local"
assert RunnerCategory.CLOUD == "cloud"
assert RunnerCategory.UNKNOWN == "unknown"
assert RunnerCategory.LOCAL == 'local'
assert RunnerCategory.CLOUD == 'cloud'
assert RunnerCategory.UNKNOWN == 'unknown'
def test_cloud_domains_not_empty(self):
"""CLOUD_DOMAINS should not be empty."""
@@ -323,5 +304,5 @@ class TestConstants:
assert len(LOCAL_PATTERNS) > 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -68,11 +68,7 @@ class TestNormalizeFilter:
def test_normalize_filter_multiple_conditions(self):
"""Multiple top-level keys are AND-ed (returned as multiple triples)."""
result = normalize_filter({
'file_id': 'abc',
'status': {'$ne': 'deleted'},
'created_at': {'$gte': 1700000000}
})
result = normalize_filter({'file_id': 'abc', 'status': {'$ne': 'deleted'}, 'created_at': {'$gte': 1700000000}})
assert len(result) == 3
# Order should match dict iteration order
@@ -149,11 +145,7 @@ class TestStripUnsupportedFields:
('file_id', '$eq', 'def'),
]
result = strip_unsupported_fields(
triples,
{'file_id', 'chunk_uuid'},
field_aliases={'uuid': 'chunk_uuid'}
)
result = strip_unsupported_fields(triples, {'file_id', 'chunk_uuid'}, field_aliases={'uuid': 'chunk_uuid'})
assert len(result) == 2
# 'uuid' should be resolved to 'chunk_uuid'
@@ -169,7 +161,7 @@ class TestStripUnsupportedFields:
result = strip_unsupported_fields(
triples,
{'file_id'}, # chunk_uuid not supported
field_aliases={'uuid': 'chunk_uuid'}
field_aliases={'uuid': 'chunk_uuid'},
)
assert result == []
@@ -207,4 +199,5 @@ class TestSupportedOpsConstant:
def test_supported_ops_is_frozenset(self):
"""SUPPORTED_OPS is a frozenset for immutability."""
from collections.abc import Set
assert isinstance(SUPPORTED_OPS, Set)
assert isinstance(SUPPORTED_OPS, Set)

View File

@@ -55,6 +55,7 @@ class TestVectorDBManagerInitialization:
# Run initialize synchronously for test
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
# Chroma should be instantiated
@@ -76,6 +77,7 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_chroma_class.assert_called_once_with(mock_app)
@@ -96,6 +98,7 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_qdrant_class.assert_called_once_with(mock_app)
@@ -115,6 +118,7 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_seekdb_class.assert_called_once_with(mock_app)
@@ -123,11 +127,7 @@ class TestVectorDBManagerInitialization:
"""Milvus config with custom URI."""
vdb_config = {
'use': 'milvus',
'milvus': {
'uri': 'http://localhost:19530',
'token': 'root:Milvus',
'db_name': 'langbot_db'
}
'milvus': {'uri': 'http://localhost:19530', 'token': 'root:Milvus', 'db_name': 'langbot_db'},
}
mock_app = self._create_mock_app(vdb_config)
@@ -141,13 +141,11 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_milvus_class.assert_called_once_with(
mock_app,
uri='http://localhost:19530',
token='root:Milvus',
db_name='langbot_db'
mock_app, uri='http://localhost:19530', token='root:Milvus', db_name='langbot_db'
)
def test_initialize_milvus_backend_defaults(self):
@@ -165,24 +163,15 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
# Should use default values
mock_milvus_class.assert_called_once_with(
mock_app,
uri='./data/milvus.db',
token=None,
db_name='default'
)
mock_milvus_class.assert_called_once_with(mock_app, uri='./data/milvus.db', token=None, db_name='default')
def test_initialize_pgvector_with_connection_string(self):
"""pgvector with connection string."""
vdb_config = {
'use': 'pgvector',
'pgvector': {
'connection_string': 'postgresql://user:pass@host:5432/langbot'
}
}
vdb_config = {'use': 'pgvector', 'pgvector': {'connection_string': 'postgresql://user:pass@host:5432/langbot'}}
mock_app = self._create_mock_app(vdb_config)
mocks = self._make_vector_import_mocks()
@@ -195,11 +184,11 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with(
mock_app,
connection_string='postgresql://user:pass@host:5432/langbot'
mock_app, connection_string='postgresql://user:pass@host:5432/langbot'
)
def test_initialize_pgvector_with_individual_params(self):
@@ -211,8 +200,8 @@ class TestVectorDBManagerInitialization:
'port': 5433,
'database': 'vectordb',
'user': 'admin',
'password': 'secret'
}
'password': 'secret',
},
}
mock_app = self._create_mock_app(vdb_config)
@@ -226,15 +215,11 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with(
mock_app,
host='db.example.com',
port=5433,
database='vectordb',
user='admin',
password='secret'
mock_app, host='db.example.com', port=5433, database='vectordb', user='admin', password='secret'
)
def test_initialize_pgvector_defaults(self):
@@ -252,15 +237,11 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_pgvector_class.assert_called_once_with(
mock_app,
host='localhost',
port=5432,
database='langbot',
user='postgres',
password='postgres'
mock_app, host='localhost', port=5432, database='langbot', user='postgres', password='postgres'
)
def test_initialize_unknown_backend_defaults_to_chroma(self):
@@ -278,6 +259,7 @@ class TestVectorDBManagerInitialization:
mgr = VectorDBManager(mock_app)
import asyncio
asyncio.get_event_loop().run_until_complete(mgr.initialize())
mock_chroma_class.assert_called_once_with(mock_app)
@@ -335,4 +317,4 @@ class TestVectorDBManagerProxies:
mgr.vector_db = mock_vector_db
result = mgr.get_supported_search_types()
assert result == ['vector', 'full_text']
assert result == ['vector', 'full_text']

View File

@@ -39,6 +39,7 @@ class TestVectorDatabaseAbstractMethods:
def test_abstract_methods_required(self):
"""Subclass must implement all abstract methods."""
class IncompleteVectorDB(VectorDatabase):
pass
@@ -47,11 +48,21 @@ class TestVectorDatabaseAbstractMethods:
def test_supported_search_types_default(self):
"""Default supported_search_types returns [VECTOR]."""
class MinimalVectorDB(VectorDatabase):
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
async def search(
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass
async def delete_by_file_id(self, collection, file_id):
@@ -71,11 +82,21 @@ class TestVectorDatabaseAbstractMethods:
def test_list_by_filter_default_implementation(self):
"""list_by_filter has default implementation returning empty."""
class MinimalVectorDB(VectorDatabase):
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
async def search(
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass
async def delete_by_file_id(self, collection, file_id):
@@ -93,9 +114,8 @@ class TestVectorDatabaseAbstractMethods:
db = MinimalVectorDB()
# list_by_filter should return empty list and -1 for total
import asyncio
result = asyncio.get_event_loop().run_until_complete(
db.list_by_filter('test_collection')
)
result = asyncio.get_event_loop().run_until_complete(db.list_by_filter('test_collection'))
assert result == ([], -1)
@@ -105,14 +125,17 @@ class TestVectorDatabaseInterface:
@pytest.fixture
def mock_vector_db(self):
"""Create a minimal mock VectorDatabase for testing."""
class MockVectorDB(VectorDatabase):
def __init__(self):
self.add_embeddings = AsyncMock()
self.search = AsyncMock(return_value={
'ids': [['id1', 'id2']],
'distances': [[0.1, 0.2]],
'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]]
})
self.search = AsyncMock(
return_value={
'ids': [['id1', 'id2']],
'distances': [[0.1, 0.2]],
'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]],
}
)
self.delete_by_file_id = AsyncMock()
self.delete_by_filter = AsyncMock(return_value=5)
self.get_or_create_collection = AsyncMock()
@@ -121,7 +144,16 @@ class TestVectorDatabaseInterface:
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
pass
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
async def search(
self,
collection,
query_embedding,
k=5,
search_type='vector',
query_text='',
filter=None,
vector_weight=None,
):
pass
async def delete_by_file_id(self, collection, file_id):
@@ -146,7 +178,7 @@ class TestVectorDatabaseInterface:
ids=['id1', 'id2'],
embeddings_list=[[0.1, 0.2], [0.3, 0.4]],
metadatas=[{'a': 1}, {'b': 2}],
documents=['doc1', 'doc2']
documents=['doc1', 'doc2'],
)
mock_vector_db.add_embeddings.assert_called_once()
@@ -162,7 +194,7 @@ class TestVectorDatabaseInterface:
search_type='hybrid',
query_text='search text',
filter={'file_id': 'abc'},
vector_weight=0.7
vector_weight=0.7,
)
mock_vector_db.search.assert_called_once()
@@ -170,4 +202,4 @@ class TestVectorDatabaseInterface:
async def test_delete_by_filter_returns_int(self, mock_vector_db):
"""delete_by_filter returns int count."""
result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'})
assert isinstance(result, int)
assert isinstance(result, int)

View File

@@ -5,6 +5,7 @@ Tests cover:
- _build_milvus_expr: Milvus boolean expression string conversion
- _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion
"""
from __future__ import annotations
from importlib import import_module
@@ -122,11 +123,13 @@ class TestQdrantFilterConversion:
"""Multiple conditions are combined in must/must_not."""
qdrant_module = get_qdrant_module()
result = qdrant_module._build_qdrant_filter({
'file_id': 'abc',
'status': {'$ne': 'deleted'},
'created_at': {'$gte': 100},
})
result = qdrant_module._build_qdrant_filter(
{
'file_id': 'abc',
'status': {'$ne': 'deleted'},
'created_at': {'$gte': 100},
}
)
assert len(result.must) == 2 # file_id eq + created_at gte
assert len(result.must_not) == 1 # status ne
@@ -198,10 +201,12 @@ class TestMilvusFilterConversion:
"""Multiple conditions are joined with 'and'."""
milvus_module = get_milvus_module()
result = milvus_module._build_milvus_expr({
'file_id': 'abc',
'chunk_uuid': {'$ne': 'def'},
})
result = milvus_module._build_milvus_expr(
{
'file_id': 'abc',
'chunk_uuid': {'$ne': 'def'},
}
)
assert 'and' in result
assert 'file_id == "abc"' in result
assert 'chunk_uuid != "def"' in result
@@ -272,6 +277,7 @@ class TestPgVectorFilterConversion:
assert len(result) == 1
# Verify it's a SQLAlchemy BinaryExpression
from sqlalchemy.sql.expression import BinaryExpression
assert isinstance(result[0], BinaryExpression)
def test_ne_operator_creates_inequality_condition(self):
@@ -321,10 +327,12 @@ class TestPgVectorFilterConversion:
"""Multiple conditions return list of conditions."""
pgvector_module = get_pgvector_module()
result = pgvector_module._build_pg_conditions({
'file_id': 'abc',
'chunk_uuid': {'$ne': 'def'},
})
result = pgvector_module._build_pg_conditions(
{
'file_id': 'abc',
'chunk_uuid': {'$ne': 'def'},
}
)
assert len(result) == 2
@@ -349,11 +357,13 @@ class TestPgVectorFilterConversion:
"""Only supported fields (text, file_id, chunk_uuid) are kept."""
pgvector_module = get_pgvector_module()
result = pgvector_module._build_pg_conditions({
'text': {'$ne': ''},
'file_id': 'abc',
'chunk_uuid': {'$in': ['x', 'y']},
'unsupported': 'value',
})
result = pgvector_module._build_pg_conditions(
{
'text': {'$ne': ''},
'file_id': 'abc',
'chunk_uuid': {'$in': ['x', 'y']},
'unsupported': 'value',
}
)
assert len(result) == 3 # Only supported fields
assert len(result) == 3 # Only supported fields

View File

@@ -1,3 +1,3 @@
"""
Test utilities package.
"""
"""

View File

@@ -26,6 +26,7 @@ from unittest.mock import MagicMock
class MockLifecycleControlScope(enum.Enum):
"""Mock enum for breaking circular import in core.entities."""
APPLICATION = 'application'
PLATFORM = 'platform'
PLUGIN = 'plugin'
@@ -190,4 +191,4 @@ def get_handler_modules_to_clear(handler_name: str) -> list[str]:
'langbot.pkg.pipeline.process.handler',
'langbot.pkg.pipeline.process.handlers',
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
]
]

2
web/.gitignore vendored
View File

@@ -12,6 +12,8 @@
# testing
/coverage
/playwright-report
/test-results
# next.js
/dist/

View File

@@ -1,3 +1,13 @@
# Debug LangBot Frontend
Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information.
## Tests
Run the frontend smoke tests without a backend process:
```bash
pnpm test:e2e
```
The Playwright suite starts Vite and mocks the LangBot backend and Space APIs.

View File

@@ -6,6 +6,7 @@
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview",
"test:e2e": "playwright test",
"lint": "eslint .",
"format": "prettier --write ."
},
@@ -86,6 +87,7 @@
"zod": "^3.24.4"
},
"devDependencies": {
"@playwright/test": "^1.61.0",
"@types/debug": "^4.1.12",
"@types/estree": "^1.0.8",
"@types/estree-jsx": "^1.0.5",

Some files were not shown because too many files have changed in this diff Show More