mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-16 18:56:02 +00:00
Compare commits
1 Commits
test/forma
...
codex/surv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5d71597f1 |
46
.github/workflows/frontend-tests.yml
vendored
46
.github/workflows/frontend-tests.yml
vendored
@@ -1,46 +0,0 @@
|
|||||||
name: Frontend Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
paths:
|
|
||||||
- 'web/**'
|
|
||||||
- '.github/workflows/frontend-tests.yml'
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
- develop
|
|
||||||
paths:
|
|
||||||
- 'web/**'
|
|
||||||
- '.github/workflows/frontend-tests.yml'
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
playwright-smoke:
|
|
||||||
name: Playwright Smoke
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: '25'
|
|
||||||
|
|
||||||
- name: Install pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
version: 8.9.2
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
working-directory: web
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Install Playwright browsers
|
|
||||||
working-directory: web
|
|
||||||
run: pnpm exec playwright install --with-deps chromium
|
|
||||||
|
|
||||||
- name: Run Playwright smoke tests
|
|
||||||
working-directory: web
|
|
||||||
run: pnpm test:e2e
|
|
||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
|||||||
run: uv sync --dev
|
run: uv sync --dev
|
||||||
|
|
||||||
- name: Run ruff check
|
- name: Run ruff check
|
||||||
run: uv run ruff check src/langbot/ tests/ --output-format=concise
|
run: uv run ruff check src
|
||||||
|
|
||||||
- name: Run ruff format
|
- name: Run ruff format
|
||||||
run: uv run ruff format src --check
|
run: uv run ruff format src --check
|
||||||
|
|||||||
63
.github/workflows/run-tests.yml
vendored
63
.github/workflows/run-tests.yml
vendored
@@ -84,67 +84,6 @@ jobs:
|
|||||||
echo "" >> $GITHUB_STEP_SUMMARY
|
echo "" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
||||||
e2e:
|
|
||||||
name: E2E Startup Tests
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: uv sync --dev
|
|
||||||
|
|
||||||
- name: Run E2E startup tests
|
|
||||||
run: uv run pytest tests/e2e -q --tb=short
|
|
||||||
|
|
||||||
- name: E2E Test Summary
|
|
||||||
if: always()
|
|
||||||
run: |
|
|
||||||
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
|
||||||
|
|
||||||
box-integration:
|
|
||||||
name: Box Integration Tests
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: uv sync --dev
|
|
||||||
|
|
||||||
- name: Check Docker runtime
|
|
||||||
run: docker info
|
|
||||||
|
|
||||||
- name: Run Box integration tests
|
|
||||||
run: uv run pytest tests/integration_tests -q --tb=short
|
|
||||||
|
|
||||||
- name: Box Integration Test Summary
|
|
||||||
if: always()
|
|
||||||
run: |
|
|
||||||
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
|
||||||
|
|
||||||
coverage:
|
coverage:
|
||||||
name: Coverage Gate
|
name: Coverage Gate
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -190,4 +129,4 @@ jobs:
|
|||||||
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
|
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "" >> $GITHUB_STEP_SUMMARY
|
echo "" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
|
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
|
||||||
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
# LangBot Test Suite
|
# LangBot Test Suite
|
||||||
|
|
||||||
This directory contains the LangBot backend test suite, including unit tests,
|
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
|
||||||
integration tests, startup E2E tests, and container-backed Box runtime tests.
|
|
||||||
|
|
||||||
## Quality Gate Layers
|
## Quality Gate Layers
|
||||||
|
|
||||||
@@ -11,15 +10,10 @@ LangBot uses a layered quality gate system for developers and CI:
|
|||||||
|-------|---------|--------------|-------------|
|
|-------|---------|--------------|-------------|
|
||||||
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
|
||||||
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
|
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
|
||||||
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
|
|
||||||
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
|
|
||||||
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
|
|
||||||
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
|
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
|
||||||
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
|
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
|
||||||
|
|
||||||
**Note**: PostgreSQL migration tests and slow tests are NOT in local default
|
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
|
||||||
gates. They run in separate CI workflows. Frontend Playwright tests live under
|
|
||||||
`web/tests/e2e` and are documented in `web/README.md`.
|
|
||||||
|
|
||||||
### Developer Workflow
|
### Developer Workflow
|
||||||
|
|
||||||
@@ -34,9 +28,6 @@ make test-all-local
|
|||||||
bash scripts/test-quick.sh # ~2 min
|
bash scripts/test-quick.sh # ~2 min
|
||||||
bash scripts/test-integration-fast.sh # ~3 min
|
bash scripts/test-integration-fast.sh # ~3 min
|
||||||
bash scripts/test-coverage.sh # ~8 min
|
bash scripts/test-coverage.sh # ~8 min
|
||||||
uv run --python 3.12 pytest tests/e2e -q --tb=short
|
|
||||||
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
|
|
||||||
cd web && pnpm test:e2e
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Coverage Baseline
|
### Coverage Baseline
|
||||||
@@ -79,12 +70,6 @@ tests/
|
|||||||
│ └── persistence/ # Database/persistence tests
|
│ └── persistence/ # Database/persistence tests
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ └── test_migrations.py # Alembic migration tests
|
│ └── test_migrations.py # Alembic migration tests
|
||||||
├── e2e/ # Real LangBot startup E2E tests
|
|
||||||
│ ├── conftest.py
|
|
||||||
│ ├── test_startup.py
|
|
||||||
│ └── utils/
|
|
||||||
├── integration_tests/ # Container-backed integration tests
|
|
||||||
│ └── box/ # Box runtime and MCP process tests
|
|
||||||
├── smoke/ # Smoke tests (quick validation)
|
├── smoke/ # Smoke tests (quick validation)
|
||||||
│ └── test_fake_message_flow.py
|
│ └── test_fake_message_flow.py
|
||||||
├── unit_tests/ # Unit tests
|
├── unit_tests/ # Unit tests
|
||||||
@@ -318,44 +303,6 @@ These tests:
|
|||||||
- Test prevent_default, exception handling, and full message flow
|
- Test prevent_default, exception handling, and full message flow
|
||||||
- Do not require real LLM provider keys
|
- Do not require real LLM provider keys
|
||||||
|
|
||||||
### Running backend E2E startup tests
|
|
||||||
|
|
||||||
Backend E2E tests start a real LangBot process with a generated minimal
|
|
||||||
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
|
|
||||||
They do not require provider keys or external services.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run --python 3.12 pytest tests/e2e -q --tb=short
|
|
||||||
```
|
|
||||||
|
|
||||||
These tests verify startup orchestration, migrations, API route registration,
|
|
||||||
and the minimal no-LLM startup path. The E2E process manager disables ambient
|
|
||||||
proxy variables for subprocess startup and uses direct localhost HTTP clients,
|
|
||||||
so local proxy settings should not affect the health checks.
|
|
||||||
|
|
||||||
### Running Box integration tests
|
|
||||||
|
|
||||||
Box integration tests exercise the real sandbox runtime path, including command
|
|
||||||
execution, session persistence, managed process WebSocket attachment, and
|
|
||||||
cleanup behavior.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
|
|
||||||
```
|
|
||||||
|
|
||||||
These tests require a working Docker or Podman runtime. In CI, the dedicated
|
|
||||||
Box integration job checks Docker availability before running the tests.
|
|
||||||
|
|
||||||
### Running frontend E2E tests
|
|
||||||
|
|
||||||
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
|
|
||||||
and mock the LangBot backend and Space APIs, so no backend process is required.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd web
|
|
||||||
pnpm test:e2e
|
|
||||||
```
|
|
||||||
|
|
||||||
### Known Issues
|
### Known Issues
|
||||||
|
|
||||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||||
@@ -373,9 +320,6 @@ Tests are automatically run on:
|
|||||||
- Push to master/develop branches
|
- Push to master/develop branches
|
||||||
|
|
||||||
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
|
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
|
||||||
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
|
|
||||||
they exercise process/container behavior instead of pure Python compatibility.
|
|
||||||
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
|
|
||||||
|
|
||||||
## Adding New Tests
|
## Adding New Tests
|
||||||
|
|
||||||
@@ -462,4 +406,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
|
|||||||
- [ ] Add E2E tests
|
- [ ] Add E2E tests
|
||||||
- [ ] Add performance benchmarks
|
- [ ] Add performance benchmarks
|
||||||
- [ ] Add mutation testing for better coverage quality
|
- [ ] Add mutation testing for better coverage quality
|
||||||
- [ ] Add property-based testing with Hypothesis
|
- [ ] Add property-based testing with Hypothesis
|
||||||
@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
|
|||||||
|
|
||||||
base_url = f'http://127.0.0.1:{e2e_port}'
|
base_url = f'http://127.0.0.1:{e2e_port}'
|
||||||
|
|
||||||
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client:
|
with httpx.Client(base_url=base_url, timeout=10.0) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def e2e_db_path(e2e_tmpdir):
|
def e2e_db_path(e2e_tmpdir):
|
||||||
"""Path to SQLite database file."""
|
"""Path to SQLite database file."""
|
||||||
return e2e_tmpdir / 'data' / 'langbot.db'
|
return e2e_tmpdir / 'data' / 'langbot.db'
|
||||||
@@ -38,13 +38,12 @@ class TestStartupFlow:
|
|||||||
# System info should contain version info
|
# System info should contain version info
|
||||||
assert 'version' in data['data'] or 'edition' in data['data']
|
assert 'version' in data['data'] or 'edition' in data['data']
|
||||||
|
|
||||||
def test_database_initialized(self, langbot_process, e2e_db_path):
|
def test_database_initialized(self, e2e_db_path):
|
||||||
"""Verify SQLite database was created and initialized."""
|
"""Verify SQLite database was created and initialized."""
|
||||||
assert e2e_db_path.exists()
|
assert e2e_db_path.exists()
|
||||||
|
|
||||||
# Database should have some tables after migration
|
# Database should have some tables after migration
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
conn = sqlite3.connect(str(e2e_db_path))
|
conn = sqlite3.connect(str(e2e_db_path))
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -75,13 +74,10 @@ class TestStartupFlow:
|
|||||||
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
|
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
|
||||||
"""Test auth endpoint."""
|
"""Test auth endpoint."""
|
||||||
# First startup may allow initial setup
|
# First startup may allow initial setup
|
||||||
response = e2e_client.post(
|
response = e2e_client.post('/api/v1/user/auth', json={
|
||||||
'/api/v1/user/auth',
|
'username': 'admin',
|
||||||
json={
|
'password': 'admin',
|
||||||
'user': 'admin',
|
})
|
||||||
'password': 'admin',
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Response could be:
|
# Response could be:
|
||||||
# - 200 if auth succeeds
|
# - 200 if auth succeeds
|
||||||
@@ -98,10 +94,9 @@ class TestStartupStages:
|
|||||||
# If API responds on e2e_port, config was loaded
|
# If API responds on e2e_port, config was loaded
|
||||||
assert e2e_client.get('/api/v1/system/info').status_code == 200
|
assert e2e_client.get('/api/v1/system/info').status_code == 200
|
||||||
|
|
||||||
def test_migrations_applied(self, langbot_process, e2e_db_path):
|
def test_migrations_applied(self, e2e_db_path):
|
||||||
"""Verify database migrations were applied."""
|
"""Verify database migrations were applied."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
conn = sqlite3.connect(str(e2e_db_path))
|
conn = sqlite3.connect(str(e2e_db_path))
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
|||||||
@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
|
|||||||
for path in directories.values():
|
for path in directories.values():
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
return directories
|
return directories
|
||||||
@@ -44,17 +44,6 @@ class LangBotProcess:
|
|||||||
# Prepare environment
|
# Prepare environment
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env['PYTHONPATH'] = str(self.project_root / 'src')
|
env['PYTHONPATH'] = str(self.project_root / 'src')
|
||||||
for proxy_key in (
|
|
||||||
'HTTP_PROXY',
|
|
||||||
'HTTPS_PROXY',
|
|
||||||
'ALL_PROXY',
|
|
||||||
'http_proxy',
|
|
||||||
'https_proxy',
|
|
||||||
'all_proxy',
|
|
||||||
):
|
|
||||||
env.pop(proxy_key, None)
|
|
||||||
env['NO_PROXY'] = '127.0.0.1,localhost'
|
|
||||||
env['no_proxy'] = '127.0.0.1,localhost'
|
|
||||||
|
|
||||||
# Set API port via environment variable
|
# Set API port via environment variable
|
||||||
env['API__PORT'] = str(self.port)
|
env['API__PORT'] = str(self.port)
|
||||||
@@ -90,11 +79,9 @@ precision = 2
|
|||||||
f.write(coveragerc_content)
|
f.write(coveragerc_content)
|
||||||
|
|
||||||
cmd = [
|
cmd = [
|
||||||
'coverage',
|
'coverage', 'run',
|
||||||
'run',
|
|
||||||
'--rcfile=' + str(coveragerc_path),
|
'--rcfile=' + str(coveragerc_path),
|
||||||
'-m',
|
'-m', 'langbot',
|
||||||
'langbot',
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
cmd = ['uv', 'run', 'python', '-m', 'langbot']
|
cmd = ['uv', 'run', 'python', '-m', 'langbot']
|
||||||
@@ -126,8 +113,6 @@ precision = 2
|
|||||||
r = httpx.get(
|
r = httpx.get(
|
||||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||||
timeout=2.0,
|
timeout=2.0,
|
||||||
follow_redirects=False,
|
|
||||||
trust_env=False,
|
|
||||||
)
|
)
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
logger.info(f'LangBot started successfully on port {self.port}')
|
logger.info(f'LangBot started successfully on port {self.port}')
|
||||||
@@ -200,8 +185,6 @@ precision = 2
|
|||||||
r = httpx.get(
|
r = httpx.get(
|
||||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||||
timeout=5.0,
|
timeout=5.0,
|
||||||
follow_redirects=False,
|
|
||||||
trust_env=False,
|
|
||||||
)
|
)
|
||||||
return r.status_code == 200
|
return r.status_code == 200
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -218,4 +201,4 @@ def find_project_root() -> Path:
|
|||||||
return parent
|
return parent
|
||||||
|
|
||||||
# Fallback to LangBot-test-build directory
|
# Fallback to LangBot-test-build directory
|
||||||
return Path('/home/glwuy/langbot-app/LangBot-test-build')
|
return Path('/home/glwuy/langbot-app/LangBot-test-build')
|
||||||
@@ -58,45 +58,45 @@ from tests.factories.platform import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# App
|
# App
|
||||||
'FakeApp',
|
"FakeApp",
|
||||||
'fake_app',
|
"fake_app",
|
||||||
# Message chains
|
# Message chains
|
||||||
'text_chain',
|
"text_chain",
|
||||||
'group_text_chain',
|
"group_text_chain",
|
||||||
'mention_chain',
|
"mention_chain",
|
||||||
'image_chain',
|
"image_chain",
|
||||||
# Message events
|
# Message events
|
||||||
'friend_message_event',
|
"friend_message_event",
|
||||||
'group_message_event',
|
"group_message_event",
|
||||||
# Mock adapters
|
# Mock adapters
|
||||||
'mock_adapter',
|
"mock_adapter",
|
||||||
# Queries
|
# Queries
|
||||||
'text_query',
|
"text_query",
|
||||||
'group_text_query',
|
"group_text_query",
|
||||||
'private_text_query',
|
"private_text_query",
|
||||||
'command_query',
|
"command_query",
|
||||||
'mention_query',
|
"mention_query",
|
||||||
'empty_query',
|
"empty_query",
|
||||||
'image_query',
|
"image_query",
|
||||||
'file_query',
|
"file_query",
|
||||||
'unsupported_query',
|
"unsupported_query",
|
||||||
'voice_query',
|
"voice_query",
|
||||||
'at_all_query',
|
"at_all_query",
|
||||||
'query_with_session',
|
"query_with_session",
|
||||||
'query_with_config',
|
"query_with_config",
|
||||||
# Provider
|
# Provider
|
||||||
'FakeProvider',
|
"FakeProvider",
|
||||||
'fake_provider',
|
"fake_provider",
|
||||||
'fake_provider_pong',
|
"fake_provider_pong",
|
||||||
'fake_provider_timeout',
|
"fake_provider_timeout",
|
||||||
'fake_provider_auth_error',
|
"fake_provider_auth_error",
|
||||||
'fake_provider_rate_limit',
|
"fake_provider_rate_limit",
|
||||||
'fake_provider_malformed',
|
"fake_provider_malformed",
|
||||||
'fake_model',
|
"fake_model",
|
||||||
# Platform
|
# Platform
|
||||||
'FakePlatform',
|
"FakePlatform",
|
||||||
'fake_platform',
|
"fake_platform",
|
||||||
'fake_platform_with_streaming',
|
"fake_platform_with_streaming",
|
||||||
'fake_platform_with_failure',
|
"fake_platform_with_failure",
|
||||||
'mock_platform_adapter',
|
"mock_platform_adapter",
|
||||||
]
|
]
|
||||||
@@ -30,36 +30,32 @@ def _next_query_id() -> int:
|
|||||||
# ============== Message Chain Factories ==============
|
# ============== Message Chain Factories ==============
|
||||||
|
|
||||||
|
|
||||||
def text_chain(text: str = 'hello') -> platform_message.MessageChain:
|
def text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||||
"""Create a simple text message chain."""
|
"""Create a simple text message chain."""
|
||||||
return platform_message.MessageChain(
|
return platform_message.MessageChain([
|
||||||
[
|
platform_message.Plain(text=text),
|
||||||
platform_message.Plain(text=text),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain:
|
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||||
"""Create a group text message chain (same as text_chain, context provided by event)."""
|
"""Create a group text message chain (same as text_chain, context provided by event)."""
|
||||||
return text_chain(text)
|
return text_chain(text)
|
||||||
|
|
||||||
|
|
||||||
def mention_chain(
|
def mention_chain(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
target: typing.Union[int, str] = 12345,
|
target: typing.Union[int, str] = 12345,
|
||||||
) -> platform_message.MessageChain:
|
) -> platform_message.MessageChain:
|
||||||
"""Create a message chain with @mention."""
|
"""Create a message chain with @mention."""
|
||||||
return platform_message.MessageChain(
|
return platform_message.MessageChain([
|
||||||
[
|
platform_message.At(target=target),
|
||||||
platform_message.At(target=target),
|
platform_message.Plain(text=f" {text}"),
|
||||||
platform_message.Plain(text=f' {text}'),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def image_chain(
|
def image_chain(
|
||||||
text: str = '',
|
text: str = "",
|
||||||
url: str = 'https://example.com/image.png',
|
url: str = "https://example.com/image.png",
|
||||||
) -> platform_message.MessageChain:
|
) -> platform_message.MessageChain:
|
||||||
"""Create a message chain with an image."""
|
"""Create a message chain with an image."""
|
||||||
components = []
|
components = []
|
||||||
@@ -70,15 +66,13 @@ def image_chain(
|
|||||||
|
|
||||||
|
|
||||||
def command_chain(
|
def command_chain(
|
||||||
command: str = 'help',
|
command: str = "help",
|
||||||
prefix: str = '/',
|
prefix: str = "/",
|
||||||
) -> platform_message.MessageChain:
|
) -> platform_message.MessageChain:
|
||||||
"""Create a command message chain."""
|
"""Create a command message chain."""
|
||||||
return platform_message.MessageChain(
|
return platform_message.MessageChain([
|
||||||
[
|
platform_message.Plain(text=f"{prefix}{command}"),
|
||||||
platform_message.Plain(text=f'{prefix}{command}'),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============== Message Event Factories ==============
|
# ============== Message Event Factories ==============
|
||||||
@@ -87,7 +81,7 @@ def command_chain(
|
|||||||
def friend_message_event(
|
def friend_message_event(
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
nickname: str = 'TestUser',
|
nickname: str = "TestUser",
|
||||||
) -> platform_events.FriendMessage:
|
) -> platform_events.FriendMessage:
|
||||||
"""Create a friend (private) message event."""
|
"""Create a friend (private) message event."""
|
||||||
sender = platform_entities.Friend(
|
sender = platform_entities.Friend(
|
||||||
@@ -96,7 +90,7 @@ def friend_message_event(
|
|||||||
remark=None,
|
remark=None,
|
||||||
)
|
)
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
type='FriendMessage',
|
type="FriendMessage",
|
||||||
sender=sender,
|
sender=sender,
|
||||||
message_chain=message_chain,
|
message_chain=message_chain,
|
||||||
time=1609459200,
|
time=1609459200,
|
||||||
@@ -106,9 +100,9 @@ def friend_message_event(
|
|||||||
def group_message_event(
|
def group_message_event(
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
sender_name: str = 'TestUser',
|
sender_name: str = "TestUser",
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
group_name: str = 'TestGroup',
|
group_name: str = "TestGroup",
|
||||||
) -> platform_events.GroupMessage:
|
) -> platform_events.GroupMessage:
|
||||||
"""Create a group message event."""
|
"""Create a group message event."""
|
||||||
group = platform_entities.Group(
|
group = platform_entities.Group(
|
||||||
@@ -123,7 +117,7 @@ def group_message_event(
|
|||||||
group=group,
|
group=group,
|
||||||
)
|
)
|
||||||
return platform_events.GroupMessage(
|
return platform_events.GroupMessage(
|
||||||
type='GroupMessage',
|
type="GroupMessage",
|
||||||
sender=sender,
|
sender=sender,
|
||||||
message_chain=message_chain,
|
message_chain=message_chain,
|
||||||
time=1609459200,
|
time=1609459200,
|
||||||
@@ -158,36 +152,36 @@ def _base_query(
|
|||||||
query_id = _next_query_id()
|
query_id = _next_query_id()
|
||||||
|
|
||||||
base_data = {
|
base_data = {
|
||||||
'query_id': query_id,
|
"query_id": query_id,
|
||||||
'launcher_type': launcher_type,
|
"launcher_type": launcher_type,
|
||||||
'launcher_id': launcher_id,
|
"launcher_id": launcher_id,
|
||||||
'sender_id': sender_id,
|
"sender_id": sender_id,
|
||||||
'message_chain': message_chain,
|
"message_chain": message_chain,
|
||||||
'message_event': message_event,
|
"message_event": message_event,
|
||||||
'adapter': adapter,
|
"adapter": adapter,
|
||||||
'pipeline_uuid': 'test-pipeline-uuid',
|
"pipeline_uuid": "test-pipeline-uuid",
|
||||||
'bot_uuid': 'test-bot-uuid',
|
"bot_uuid": "test-bot-uuid",
|
||||||
'pipeline_config': {
|
"pipeline_config": {
|
||||||
'ai': {
|
"ai": {
|
||||||
'runner': {'runner': 'local-agent'},
|
"runner": {"runner": "local-agent"},
|
||||||
'local-agent': {
|
"local-agent": {
|
||||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||||
'prompt': 'test-prompt',
|
"prompt": "test-prompt",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
"trigger": {"misc": {"combine-quote-message": False}},
|
||||||
},
|
},
|
||||||
'session': None,
|
"session": None,
|
||||||
'prompt': None,
|
"prompt": None,
|
||||||
'messages': [],
|
"messages": [],
|
||||||
'user_message': None,
|
"user_message": None,
|
||||||
'use_funcs': [],
|
"use_funcs": [],
|
||||||
'use_llm_model_uuid': None,
|
"use_llm_model_uuid": None,
|
||||||
'variables': {},
|
"variables": {},
|
||||||
'resp_messages': [],
|
"resp_messages": [],
|
||||||
'resp_message_chain': None,
|
"resp_message_chain": None,
|
||||||
'current_stage_name': None,
|
"current_stage_name": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Apply overrides
|
# Apply overrides
|
||||||
@@ -198,7 +192,7 @@ def _base_query(
|
|||||||
|
|
||||||
|
|
||||||
def text_query(
|
def text_query(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -218,7 +212,7 @@ def text_query(
|
|||||||
|
|
||||||
|
|
||||||
def private_text_query(
|
def private_text_query(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -227,7 +221,7 @@ def private_text_query(
|
|||||||
|
|
||||||
|
|
||||||
def group_text_query(
|
def group_text_query(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
**overrides,
|
**overrides,
|
||||||
@@ -248,8 +242,8 @@ def group_text_query(
|
|||||||
|
|
||||||
|
|
||||||
def command_query(
|
def command_query(
|
||||||
command: str = 'help',
|
command: str = "help",
|
||||||
prefix: str = '/',
|
prefix: str = "/",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -269,7 +263,7 @@ def command_query(
|
|||||||
|
|
||||||
|
|
||||||
def mention_query(
|
def mention_query(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
target: typing.Union[int, str] = 12345,
|
target: typing.Union[int, str] = 12345,
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
@@ -307,8 +301,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
|
|||||||
|
|
||||||
|
|
||||||
def image_query(
|
def image_query(
|
||||||
text: str = '',
|
text: str = "",
|
||||||
url: str = 'https://example.com/image.png',
|
url: str = "https://example.com/image.png",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -328,9 +322,9 @@ def image_query(
|
|||||||
|
|
||||||
|
|
||||||
def file_query(
|
def file_query(
|
||||||
url: str = 'https://example.com/document.pdf',
|
url: str = "https://example.com/document.pdf",
|
||||||
name: str = 'document.pdf',
|
name: str = "document.pdf",
|
||||||
text: str = '',
|
text: str = "",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -354,8 +348,8 @@ def file_query(
|
|||||||
|
|
||||||
|
|
||||||
def unsupported_query(
|
def unsupported_query(
|
||||||
unsupported_type: str = 'CustomComponent',
|
unsupported_type: str = "CustomComponent",
|
||||||
text: str = '',
|
text: str = "",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -364,7 +358,7 @@ def unsupported_query(
|
|||||||
if text:
|
if text:
|
||||||
components.append(platform_message.Plain(text=text))
|
components.append(platform_message.Plain(text=text))
|
||||||
# Use Unknown component for unsupported types
|
# Use Unknown component for unsupported types
|
||||||
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}'))
|
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
|
||||||
chain = platform_message.MessageChain(components)
|
chain = platform_message.MessageChain(components)
|
||||||
event = friend_message_event(chain, sender_id)
|
event = friend_message_event(chain, sender_id)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
@@ -380,7 +374,7 @@ def unsupported_query(
|
|||||||
|
|
||||||
|
|
||||||
def query_with_session(
|
def query_with_session(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
session: provider_session.Session = None,
|
session: provider_session.Session = None,
|
||||||
**overrides,
|
**overrides,
|
||||||
@@ -395,7 +389,7 @@ def query_with_session(
|
|||||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||||
launcher_id=sender_id,
|
launcher_id=sender_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
use_prompt_name='default',
|
use_prompt_name="default",
|
||||||
using_conversation=None,
|
using_conversation=None,
|
||||||
conversations=[],
|
conversations=[],
|
||||||
)
|
)
|
||||||
@@ -404,7 +398,7 @@ def query_with_session(
|
|||||||
|
|
||||||
|
|
||||||
def query_with_config(
|
def query_with_config(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
pipeline_config: dict = None,
|
pipeline_config: dict = None,
|
||||||
**overrides,
|
**overrides,
|
||||||
@@ -416,22 +410,22 @@ def query_with_config(
|
|||||||
"""
|
"""
|
||||||
if pipeline_config is None:
|
if pipeline_config is None:
|
||||||
pipeline_config = {
|
pipeline_config = {
|
||||||
'ai': {
|
"ai": {
|
||||||
'runner': {'runner': 'local-agent'},
|
"runner": {"runner": "local-agent"},
|
||||||
'local-agent': {
|
"local-agent": {
|
||||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||||
'prompt': 'test-prompt',
|
"prompt": "test-prompt",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||||
'trigger': {'misc': {'combine-quote-message': False}},
|
"trigger": {"misc": {"combine-quote-message": False}},
|
||||||
}
|
}
|
||||||
|
|
||||||
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
|
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
|
||||||
|
|
||||||
|
|
||||||
def voice_query(
|
def voice_query(
|
||||||
url: str = 'https://example.com/audio.mp3',
|
url: str = "https://example.com/audio.mp3",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
**overrides,
|
**overrides,
|
||||||
) -> pipeline_query.Query:
|
) -> pipeline_query.Query:
|
||||||
@@ -454,7 +448,7 @@ def voice_query(
|
|||||||
|
|
||||||
|
|
||||||
def at_all_query(
|
def at_all_query(
|
||||||
text: str = 'hello',
|
text: str = "hello",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
**overrides,
|
**overrides,
|
||||||
@@ -462,7 +456,7 @@ def at_all_query(
|
|||||||
"""Create a group query with @All mention."""
|
"""Create a group query with @All mention."""
|
||||||
components = [
|
components = [
|
||||||
platform_message.AtAll(),
|
platform_message.AtAll(),
|
||||||
platform_message.Plain(text=f' {text}'),
|
platform_message.Plain(text=f" {text}"),
|
||||||
]
|
]
|
||||||
chain = platform_message.MessageChain(components)
|
chain = platform_message.MessageChain(components)
|
||||||
event = group_message_event(chain, sender_id, group_id=group_id)
|
event = group_message_event(chain, sender_id, group_id=group_id)
|
||||||
@@ -475,4 +469,4 @@ def at_all_query(
|
|||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
**overrides,
|
**overrides,
|
||||||
)
|
)
|
||||||
@@ -33,7 +33,7 @@ class FakePlatform:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
bot_account_id: str = 'test-bot',
|
bot_account_id: str = "test-bot",
|
||||||
stream_output_supported: bool = False,
|
stream_output_supported: bool = False,
|
||||||
raise_error: Exception = None,
|
raise_error: Exception = None,
|
||||||
):
|
):
|
||||||
@@ -48,16 +48,16 @@ class FakePlatform:
|
|||||||
# Registered listeners
|
# Registered listeners
|
||||||
self._listeners: dict = {}
|
self._listeners: dict = {}
|
||||||
|
|
||||||
def raises(self, error: Exception) -> 'FakePlatform':
|
def raises(self, error: Exception) -> "FakePlatform":
|
||||||
"""Configure platform to raise an error on send."""
|
"""Configure platform to raise an error on send."""
|
||||||
self._raise_error = error
|
self._raise_error = error
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def send_failure(self) -> 'FakePlatform':
|
def send_failure(self) -> "FakePlatform":
|
||||||
"""Configure platform to simulate send failure."""
|
"""Configure platform to simulate send failure."""
|
||||||
return self.raises(Exception('Platform send failure'))
|
return self.raises(Exception("Platform send failure"))
|
||||||
|
|
||||||
def supports_streaming(self, supported: bool = True) -> 'FakePlatform':
|
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
|
||||||
"""Configure whether streaming output is supported."""
|
"""Configure whether streaming output is supported."""
|
||||||
self._stream_output_supported = supported
|
self._stream_output_supported = supported
|
||||||
return self
|
return self
|
||||||
@@ -89,7 +89,7 @@ class FakePlatform:
|
|||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
nickname: str = 'TestUser',
|
nickname: str = "TestUser",
|
||||||
) -> platform_events.FriendMessage:
|
) -> platform_events.FriendMessage:
|
||||||
"""Create an inbound friend (private) message event."""
|
"""Create an inbound friend (private) message event."""
|
||||||
sender = platform_entities.Friend(
|
sender = platform_entities.Friend(
|
||||||
@@ -97,13 +97,11 @@ class FakePlatform:
|
|||||||
nickname=nickname,
|
nickname=nickname,
|
||||||
remark=None,
|
remark=None,
|
||||||
)
|
)
|
||||||
chain = platform_message.MessageChain(
|
chain = platform_message.MessageChain([
|
||||||
[
|
platform_message.Plain(text=text),
|
||||||
platform_message.Plain(text=text),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
type='FriendMessage',
|
type="FriendMessage",
|
||||||
sender=sender,
|
sender=sender,
|
||||||
message_chain=chain,
|
message_chain=chain,
|
||||||
time=1609459200,
|
time=1609459200,
|
||||||
@@ -113,9 +111,9 @@ class FakePlatform:
|
|||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
sender_name: str = 'TestUser',
|
sender_name: str = "TestUser",
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
group_name: str = 'TestGroup',
|
group_name: str = "TestGroup",
|
||||||
mention_bot: bool = False,
|
mention_bot: bool = False,
|
||||||
) -> platform_events.GroupMessage:
|
) -> platform_events.GroupMessage:
|
||||||
"""Create an inbound group message event.
|
"""Create an inbound group message event.
|
||||||
@@ -144,12 +142,12 @@ class FakePlatform:
|
|||||||
components = []
|
components = []
|
||||||
if mention_bot:
|
if mention_bot:
|
||||||
components.append(platform_message.At(target=self.bot_account_id))
|
components.append(platform_message.At(target=self.bot_account_id))
|
||||||
components.append(platform_message.Plain(text=' '))
|
components.append(platform_message.Plain(text=" "))
|
||||||
components.append(platform_message.Plain(text=text))
|
components.append(platform_message.Plain(text=text))
|
||||||
|
|
||||||
chain = platform_message.MessageChain(components)
|
chain = platform_message.MessageChain(components)
|
||||||
return platform_events.GroupMessage(
|
return platform_events.GroupMessage(
|
||||||
type='GroupMessage',
|
type="GroupMessage",
|
||||||
sender=sender,
|
sender=sender,
|
||||||
message_chain=chain,
|
message_chain=chain,
|
||||||
time=1609459200,
|
time=1609459200,
|
||||||
@@ -157,8 +155,8 @@ class FakePlatform:
|
|||||||
|
|
||||||
def create_image_message(
|
def create_image_message(
|
||||||
self,
|
self,
|
||||||
url: str = 'https://example.com/image.png',
|
url: str = "https://example.com/image.png",
|
||||||
text: str = '',
|
text: str = "",
|
||||||
sender_id: typing.Union[int, str] = 12345,
|
sender_id: typing.Union[int, str] = 12345,
|
||||||
is_group: bool = False,
|
is_group: bool = False,
|
||||||
group_id: typing.Union[int, str] = 99999,
|
group_id: typing.Union[int, str] = 99999,
|
||||||
@@ -171,12 +169,12 @@ class FakePlatform:
|
|||||||
chain = platform_message.MessageChain(components)
|
chain = platform_message.MessageChain(components)
|
||||||
|
|
||||||
if is_group:
|
if is_group:
|
||||||
return self.create_group_message('', sender_id, group_id=group_id)
|
return self.create_group_message("", sender_id, group_id=group_id)
|
||||||
# Replace chain
|
# Replace chain
|
||||||
else:
|
else:
|
||||||
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None)
|
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
|
||||||
return platform_events.FriendMessage(
|
return platform_events.FriendMessage(
|
||||||
type='FriendMessage',
|
type="FriendMessage",
|
||||||
sender=sender,
|
sender=sender,
|
||||||
message_chain=chain,
|
message_chain=chain,
|
||||||
time=1609459200,
|
time=1609459200,
|
||||||
@@ -194,14 +192,12 @@ class FakePlatform:
|
|||||||
if self._raise_error:
|
if self._raise_error:
|
||||||
raise self._raise_error
|
raise self._raise_error
|
||||||
|
|
||||||
self._outbound_messages.append(
|
self._outbound_messages.append({
|
||||||
{
|
"type": "send",
|
||||||
'type': 'send',
|
"target_type": target_type,
|
||||||
'target_type': target_type,
|
"target_id": target_id,
|
||||||
'target_id': target_id,
|
"message": message,
|
||||||
'message': message,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -213,15 +209,13 @@ class FakePlatform:
|
|||||||
if self._raise_error:
|
if self._raise_error:
|
||||||
raise self._raise_error
|
raise self._raise_error
|
||||||
|
|
||||||
self._outbound_messages.append(
|
self._outbound_messages.append({
|
||||||
{
|
"type": "reply",
|
||||||
'type': 'reply',
|
"source_type": message_source.type,
|
||||||
'source_type': message_source.type,
|
"source": message_source,
|
||||||
'source': message_source,
|
"message": message,
|
||||||
'message': message,
|
"quote_origin": quote_origin,
|
||||||
'quote_origin': quote_origin,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def reply_message_chunk(
|
async def reply_message_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -235,17 +229,15 @@ class FakePlatform:
|
|||||||
if self._raise_error:
|
if self._raise_error:
|
||||||
raise self._raise_error
|
raise self._raise_error
|
||||||
|
|
||||||
self._outbound_chunks.append(
|
self._outbound_chunks.append({
|
||||||
{
|
"type": "reply_chunk",
|
||||||
'type': 'reply_chunk',
|
"source_type": message_source.type,
|
||||||
'source_type': message_source.type,
|
"source": message_source,
|
||||||
'source': message_source,
|
"bot_message": bot_message,
|
||||||
'bot_message': bot_message,
|
"message": message,
|
||||||
'message': message,
|
"quote_origin": quote_origin,
|
||||||
'quote_origin': quote_origin,
|
"is_final": is_final,
|
||||||
'is_final': is_final,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def is_stream_output_supported(self) -> bool:
|
async def is_stream_output_supported(self) -> bool:
|
||||||
"""Return whether streaming output is supported."""
|
"""Return whether streaming output is supported."""
|
||||||
@@ -303,7 +295,7 @@ class FakePlatform:
|
|||||||
|
|
||||||
|
|
||||||
def fake_platform(
|
def fake_platform(
|
||||||
bot_account_id: str = 'test-bot',
|
bot_account_id: str = "test-bot",
|
||||||
stream_output_supported: bool = False,
|
stream_output_supported: bool = False,
|
||||||
) -> FakePlatform:
|
) -> FakePlatform:
|
||||||
"""Create a FakePlatform instance."""
|
"""Create a FakePlatform instance."""
|
||||||
@@ -336,7 +328,9 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
|
|||||||
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
|
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
|
||||||
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
|
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
|
||||||
adapter.send_message = AsyncMock(side_effect=platform.send_message)
|
adapter.send_message = AsyncMock(side_effect=platform.send_message)
|
||||||
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported)
|
adapter.is_stream_output_supported = AsyncMock(
|
||||||
|
return_value=platform._stream_output_supported
|
||||||
|
)
|
||||||
adapter._fake_platform = platform # Store for assertions
|
adapter._fake_platform = platform # Store for assertions
|
||||||
|
|
||||||
return adapter
|
return adapter
|
||||||
@@ -27,51 +27,51 @@ class FakeProvider:
|
|||||||
Does not require API keys.
|
Does not require API keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PONG_RESPONSE = 'LANGBOT_FAKE_PONG'
|
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
default_response: str = 'fake response',
|
default_response: str = "fake response",
|
||||||
streaming_chunks: list[str] = None,
|
streaming_chunks: list[str] = None,
|
||||||
raise_error: Exception = None,
|
raise_error: Exception = None,
|
||||||
captured_requests: list = None,
|
captured_requests: list = None,
|
||||||
):
|
):
|
||||||
self._default_response = default_response
|
self._default_response = default_response
|
||||||
self._streaming_chunks = streaming_chunks or ['fake ', 'response']
|
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
|
||||||
self._raise_error = raise_error
|
self._raise_error = raise_error
|
||||||
self._captured_requests = captured_requests if captured_requests is not None else []
|
self._captured_requests = captured_requests if captured_requests is not None else []
|
||||||
|
|
||||||
def returns(self, text: str) -> 'FakeProvider':
|
def returns(self, text: str) -> "FakeProvider":
|
||||||
"""Configure provider to return a specific text response."""
|
"""Configure provider to return a specific text response."""
|
||||||
self._default_response = text
|
self._default_response = text
|
||||||
self._streaming_chunks = [text]
|
self._streaming_chunks = [text]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
|
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
|
||||||
"""Configure provider to return streaming chunks."""
|
"""Configure provider to return streaming chunks."""
|
||||||
self._streaming_chunks = chunks
|
self._streaming_chunks = chunks
|
||||||
self._default_response = ''.join(chunks)
|
self._default_response = "".join(chunks)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def raises(self, error: Exception) -> 'FakeProvider':
|
def raises(self, error: Exception) -> "FakeProvider":
|
||||||
"""Configure provider to raise an error."""
|
"""Configure provider to raise an error."""
|
||||||
self._raise_error = error
|
self._raise_error = error
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def timeout(self) -> 'FakeProvider':
|
def timeout(self) -> "FakeProvider":
|
||||||
"""Configure provider to simulate timeout."""
|
"""Configure provider to simulate timeout."""
|
||||||
return self.raises(TimeoutError('Provider timeout'))
|
return self.raises(TimeoutError("Provider timeout"))
|
||||||
|
|
||||||
def auth_error(self) -> 'FakeProvider':
|
def auth_error(self) -> "FakeProvider":
|
||||||
"""Configure provider to simulate auth error."""
|
"""Configure provider to simulate auth error."""
|
||||||
return self.raises(Exception('Invalid API key'))
|
return self.raises(Exception("Invalid API key"))
|
||||||
|
|
||||||
def rate_limit(self) -> 'FakeProvider':
|
def rate_limit(self) -> "FakeProvider":
|
||||||
"""Configure provider to simulate rate limit."""
|
"""Configure provider to simulate rate limit."""
|
||||||
return self.raises(Exception('Rate limit exceeded'))
|
return self.raises(Exception("Rate limit exceeded"))
|
||||||
|
|
||||||
def malformed(self) -> 'FakeProvider':
|
def malformed(self) -> "FakeProvider":
|
||||||
"""Configure provider to simulate malformed response."""
|
"""Configure provider to simulate malformed response."""
|
||||||
self._default_response = None
|
self._default_response = None
|
||||||
return self
|
return self
|
||||||
@@ -87,7 +87,7 @@ class FakeProvider:
|
|||||||
def _create_message(self, content: str) -> provider_message.Message:
|
def _create_message(self, content: str) -> provider_message.Message:
|
||||||
"""Create a provider message from text content."""
|
"""Create a provider message from text content."""
|
||||||
return provider_message.Message(
|
return provider_message.Message(
|
||||||
role='assistant',
|
role="assistant",
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ class FakeProvider:
|
|||||||
) -> provider_message.MessageChunk:
|
) -> provider_message.MessageChunk:
|
||||||
"""Create a provider message chunk."""
|
"""Create a provider message chunk."""
|
||||||
return provider_message.MessageChunk(
|
return provider_message.MessageChunk(
|
||||||
role='assistant',
|
role="assistant",
|
||||||
content=content,
|
content=content,
|
||||||
is_final=is_final,
|
is_final=is_final,
|
||||||
msg_sequence=msg_sequence,
|
msg_sequence=msg_sequence,
|
||||||
@@ -116,15 +116,13 @@ class FakeProvider:
|
|||||||
) -> provider_message.Message:
|
) -> provider_message.Message:
|
||||||
"""Simulate non-streaming LLM invocation."""
|
"""Simulate non-streaming LLM invocation."""
|
||||||
# Capture request for assertions
|
# Capture request for assertions
|
||||||
self._captured_requests.append(
|
self._captured_requests.append({
|
||||||
{
|
"query_id": query.query_id if query else None,
|
||||||
'query_id': query.query_id if query else None,
|
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||||
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
"messages": messages,
|
||||||
'messages': messages,
|
"funcs": funcs,
|
||||||
'funcs': funcs,
|
"extra_args": extra_args,
|
||||||
'extra_args': extra_args,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simulate error if configured
|
# Simulate error if configured
|
||||||
if self._raise_error:
|
if self._raise_error:
|
||||||
@@ -133,7 +131,7 @@ class FakeProvider:
|
|||||||
# Return response
|
# Return response
|
||||||
if self._default_response is None:
|
if self._default_response is None:
|
||||||
# Malformed response
|
# Malformed response
|
||||||
return provider_message.Message(role='assistant', content=None)
|
return provider_message.Message(role="assistant", content=None)
|
||||||
|
|
||||||
return self._create_message(self._default_response)
|
return self._create_message(self._default_response)
|
||||||
|
|
||||||
@@ -148,16 +146,14 @@ class FakeProvider:
|
|||||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||||
"""Simulate streaming LLM invocation."""
|
"""Simulate streaming LLM invocation."""
|
||||||
# Capture request for assertions
|
# Capture request for assertions
|
||||||
self._captured_requests.append(
|
self._captured_requests.append({
|
||||||
{
|
"query_id": query.query_id if query else None,
|
||||||
'query_id': query.query_id if query else None,
|
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||||
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
"messages": messages,
|
||||||
'messages': messages,
|
"funcs": funcs,
|
||||||
'funcs': funcs,
|
"extra_args": extra_args,
|
||||||
'extra_args': extra_args,
|
"streaming": True,
|
||||||
'streaming': True,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simulate error if configured
|
# Simulate error if configured
|
||||||
if self._raise_error:
|
if self._raise_error:
|
||||||
@@ -165,12 +161,12 @@ class FakeProvider:
|
|||||||
|
|
||||||
# Yield chunks
|
# Yield chunks
|
||||||
for i, chunk in enumerate(self._streaming_chunks):
|
for i, chunk in enumerate(self._streaming_chunks):
|
||||||
is_final = i == len(self._streaming_chunks) - 1
|
is_final = (i == len(self._streaming_chunks) - 1)
|
||||||
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
|
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
|
||||||
|
|
||||||
|
|
||||||
def fake_provider(
|
def fake_provider(
|
||||||
default_response: str = 'fake response',
|
default_response: str = "fake response",
|
||||||
) -> FakeProvider:
|
) -> FakeProvider:
|
||||||
"""Create a FakeProvider with optional default response."""
|
"""Create a FakeProvider with optional default response."""
|
||||||
return FakeProvider(default_response=default_response)
|
return FakeProvider(default_response=default_response)
|
||||||
@@ -206,8 +202,8 @@ def fake_provider_malformed() -> FakeProvider:
|
|||||||
|
|
||||||
def fake_model(
|
def fake_model(
|
||||||
*,
|
*,
|
||||||
uuid: str = 'test-model-uuid',
|
uuid: str = "test-model-uuid",
|
||||||
name: str = 'test-model',
|
name: str = "test-model",
|
||||||
abilities: list[str] = None,
|
abilities: list[str] = None,
|
||||||
provider: FakeProvider = None,
|
provider: FakeProvider = None,
|
||||||
) -> Mock:
|
) -> Mock:
|
||||||
@@ -216,7 +212,7 @@ def fake_model(
|
|||||||
model.model_entity = Mock()
|
model.model_entity = Mock()
|
||||||
model.model_entity.uuid = uuid
|
model.model_entity.uuid = uuid
|
||||||
model.model_entity.name = name
|
model.model_entity.name = name
|
||||||
model.model_entity.abilities = abilities or ['func_call', 'vision']
|
model.model_entity.abilities = abilities or ["func_call", "vision"]
|
||||||
model.model_entity.extra_args = {}
|
model.model_entity.extra_args = {}
|
||||||
|
|
||||||
# Attach fake provider
|
# Attach fake provider
|
||||||
@@ -225,4 +221,4 @@ def fake_model(
|
|||||||
|
|
||||||
model.provider = provider
|
model.provider = provider
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -3,4 +3,4 @@ Integration tests package.
|
|||||||
|
|
||||||
These tests validate real system behavior with actual database/network resources.
|
These tests validate real system behavior with actual database/network resources.
|
||||||
Run with: uv run pytest tests/integration/ -m "not slow" -q
|
Run with: uv run pytest tests/integration/ -m "not slow" -q
|
||||||
"""
|
"""
|
||||||
@@ -2,4 +2,4 @@
|
|||||||
API integration tests package.
|
API integration tests package.
|
||||||
|
|
||||||
Tests for HTTP API endpoints using Quart test client.
|
Tests for HTTP API endpoints using Quart test client.
|
||||||
"""
|
"""
|
||||||
@@ -48,7 +48,6 @@ def mock_circular_import_chain():
|
|||||||
clear=clear,
|
clear=clear,
|
||||||
):
|
):
|
||||||
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -57,12 +56,10 @@ def fake_bot_app():
|
|||||||
"""Create FakeApp with bot services (module scope for reuse)."""
|
"""Create FakeApp with bot services (module scope for reuse)."""
|
||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auth services
|
# Auth services
|
||||||
app.user_service = Mock()
|
app.user_service = Mock()
|
||||||
@@ -74,29 +71,28 @@ def fake_bot_app():
|
|||||||
|
|
||||||
# Bot service
|
# Bot service
|
||||||
app.bot_service = Mock()
|
app.bot_service = Mock()
|
||||||
app.bot_service.get_bots = AsyncMock(
|
app.bot_service.get_bots = AsyncMock(return_value=[
|
||||||
return_value=[
|
{
|
||||||
{
|
|
||||||
'uuid': 'test-bot-uuid',
|
|
||||||
'name': 'Test Bot',
|
|
||||||
'platform': 'telegram',
|
|
||||||
'pipeline_uuid': 'test-pipeline-uuid',
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
app.bot_service.get_runtime_bot_info = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
'uuid': 'test-bot-uuid',
|
'uuid': 'test-bot-uuid',
|
||||||
'name': 'Test Bot',
|
'name': 'Test Bot',
|
||||||
'platform': 'telegram',
|
'platform': 'telegram',
|
||||||
'pipeline_uuid': 'test-pipeline-uuid',
|
'pipeline_uuid': 'test-pipeline-uuid',
|
||||||
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
|
|
||||||
}
|
}
|
||||||
)
|
])
|
||||||
|
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
|
||||||
|
'uuid': 'test-bot-uuid',
|
||||||
|
'name': 'Test Bot',
|
||||||
|
'platform': 'telegram',
|
||||||
|
'pipeline_uuid': 'test-pipeline-uuid',
|
||||||
|
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
|
||||||
|
})
|
||||||
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
|
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
|
||||||
app.bot_service.update_bot = AsyncMock(return_value={})
|
app.bot_service.update_bot = AsyncMock(return_value={})
|
||||||
app.bot_service.delete_bot = AsyncMock()
|
app.bot_service.delete_bot = AsyncMock()
|
||||||
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1))
|
app.bot_service.list_event_logs = AsyncMock(return_value=(
|
||||||
|
[{'uuid': 'log-1', 'message': 'test log'}],
|
||||||
|
1
|
||||||
|
))
|
||||||
app.bot_service.send_message = AsyncMock()
|
app.bot_service.send_message = AsyncMock()
|
||||||
|
|
||||||
# Platform manager
|
# Platform manager
|
||||||
@@ -122,7 +118,10 @@ class TestBotEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_bots_success(self, quart_test_client):
|
async def test_get_bots_success(self, quart_test_client):
|
||||||
"""GET /api/v1/platform/bots returns bot list."""
|
"""GET /api/v1/platform/bots returns bot list."""
|
||||||
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'})
|
response = await quart_test_client.get(
|
||||||
|
'/api/v1/platform/bots',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = await response.get_json()
|
data = await response.get_json()
|
||||||
@@ -136,7 +135,7 @@ class TestBotEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/platform/bots',
|
'/api/v1/platform/bots',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'},
|
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -148,7 +147,8 @@ class TestBotEndpoints:
|
|||||||
async def test_get_single_bot_success(self, quart_test_client):
|
async def test_get_single_bot_success(self, quart_test_client):
|
||||||
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
|
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/platform/bots/test-bot-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -162,7 +162,7 @@ class TestBotEndpoints:
|
|||||||
response = await quart_test_client.put(
|
response = await quart_test_client.put(
|
||||||
'/api/v1/platform/bots/test-bot-uuid',
|
'/api/v1/platform/bots/test-bot-uuid',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'Updated Bot'},
|
json={'name': 'Updated Bot'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -173,7 +173,8 @@ class TestBotEndpoints:
|
|||||||
async def test_delete_bot_success(self, quart_test_client):
|
async def test_delete_bot_success(self, quart_test_client):
|
||||||
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
|
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
|
||||||
response = await quart_test_client.delete(
|
response = await quart_test_client.delete(
|
||||||
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/platform/bots/test-bot-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -189,7 +190,7 @@ class TestBotLogsEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/platform/bots/test-bot-uuid/logs',
|
'/api/v1/platform/bots/test-bot-uuid/logs',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'from_index': -1, 'max_count': 10},
|
json={'from_index': -1, 'max_count': 10}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -212,8 +213,8 @@ class TestBotSendMessageEndpoint:
|
|||||||
json={
|
json={
|
||||||
'target_type': 'person',
|
'target_type': 'person',
|
||||||
'target_id': 'user123',
|
'target_id': 'user123',
|
||||||
'message_chain': [{'type': 'text', 'text': 'Hello'}],
|
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -227,7 +228,7 @@ class TestBotSendMessageEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
||||||
headers={'Authorization': 'Bearer test_api_key'},
|
headers={'Authorization': 'Bearer test_api_key'},
|
||||||
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]},
|
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -243,8 +244,8 @@ class TestBotSendMessageEndpoint:
|
|||||||
json={
|
json={
|
||||||
'target_type': 'invalid',
|
'target_type': 'invalid',
|
||||||
'target_id': 'user123',
|
'target_id': 'user123',
|
||||||
'message_chain': [{'type': 'text', 'text': 'Hello'}],
|
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ def mock_circular_import_chain():
|
|||||||
clear=clear,
|
clear=clear,
|
||||||
):
|
):
|
||||||
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -56,12 +55,10 @@ def fake_embed_app():
|
|||||||
"""Create FakeApp with embed widget services (module scope)."""
|
"""Create FakeApp with embed widget services (module scope)."""
|
||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create mock web_page_bot with valid UUID format
|
# Create mock web_page_bot with valid UUID format
|
||||||
mock_bot_entity = Mock()
|
mock_bot_entity = Mock()
|
||||||
@@ -86,7 +83,9 @@ def fake_embed_app():
|
|||||||
|
|
||||||
# WebSocket proxy bot with adapter
|
# WebSocket proxy bot with adapter
|
||||||
mock_websocket_adapter = Mock()
|
mock_websocket_adapter = Mock()
|
||||||
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}])
|
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
|
||||||
|
{'id': 'msg-1', 'content': 'test message'}
|
||||||
|
])
|
||||||
mock_websocket_adapter.reset_session = Mock()
|
mock_websocket_adapter.reset_session = Mock()
|
||||||
mock_websocket_adapter.handle_websocket_message = AsyncMock()
|
mock_websocket_adapter.handle_websocket_message = AsyncMock()
|
||||||
|
|
||||||
@@ -118,7 +117,9 @@ class TestEmbedWidgetEndpoint:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_widget_js_success(self, quart_test_client):
|
async def test_get_widget_js_success(self, quart_test_client):
|
||||||
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
|
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
|
||||||
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js')
|
response = await quart_test_client.get(
|
||||||
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert 'javascript' in response.content_type
|
assert 'javascript' in response.content_type
|
||||||
@@ -126,14 +127,18 @@ class TestEmbedWidgetEndpoint:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
|
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
|
||||||
"""GET widget.js with invalid UUID returns 400."""
|
"""GET widget.js with invalid UUID returns 400."""
|
||||||
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js')
|
response = await quart_test_client.get(
|
||||||
|
'/api/v1/embed/invalid-uuid/widget.js'
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_widget_js_bot_not_found(self, quart_test_client):
|
async def test_get_widget_js_bot_not_found(self, quart_test_client):
|
||||||
"""GET widget.js for non-existent bot returns 404."""
|
"""GET widget.js for non-existent bot returns 404."""
|
||||||
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js')
|
response = await quart_test_client.get(
|
||||||
|
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
@@ -159,7 +164,8 @@ class TestEmbedTurnstileVerifyEndpoint:
|
|||||||
async def test_turnstile_verify_no_secret(self, quart_test_client):
|
async def test_turnstile_verify_no_secret(self, quart_test_client):
|
||||||
"""POST turnstile verify without secret returns dummy token."""
|
"""POST turnstile verify without secret returns dummy token."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'}
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||||
|
json={'token': 'test-token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -171,7 +177,8 @@ class TestEmbedTurnstileVerifyEndpoint:
|
|||||||
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
|
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
|
||||||
"""POST turnstile verify with invalid UUID returns 400."""
|
"""POST turnstile verify with invalid UUID returns 400."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'}
|
'/api/v1/embed/invalid-uuid/turnstile/verify',
|
||||||
|
json={'token': 'test-token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -180,7 +187,8 @@ class TestEmbedTurnstileVerifyEndpoint:
|
|||||||
async def test_turnstile_verify_missing_token(self, quart_test_client):
|
async def test_turnstile_verify_missing_token(self, quart_test_client):
|
||||||
"""POST turnstile verify without token returns 400."""
|
"""POST turnstile verify without token returns 400."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={}
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||||
|
json={}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -195,7 +203,7 @@ class TestEmbedMessagesEndpoint:
|
|||||||
"""GET messages/person returns messages."""
|
"""GET messages/person returns messages."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -208,7 +216,7 @@ class TestEmbedMessagesEndpoint:
|
|||||||
"""GET messages/group returns messages."""
|
"""GET messages/group returns messages."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -218,7 +226,7 @@ class TestEmbedMessagesEndpoint:
|
|||||||
"""GET messages with invalid session_type returns 400."""
|
"""GET messages with invalid session_type returns 400."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -233,7 +241,7 @@ class TestEmbedResetEndpoint:
|
|||||||
"""POST reset/person resets session."""
|
"""POST reset/person resets session."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -244,7 +252,8 @@ class TestEmbedResetEndpoint:
|
|||||||
async def test_reset_session_invalid_uuid(self, quart_test_client):
|
async def test_reset_session_invalid_uuid(self, quart_test_client):
|
||||||
"""POST reset with invalid UUID returns 400."""
|
"""POST reset with invalid UUID returns 400."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'}
|
'/api/v1/embed/invalid-uuid/reset/person',
|
||||||
|
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -260,7 +269,7 @@ class TestEmbedFeedbackEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||||
json={'message_id': 'msg-123', 'feedback_type': 1},
|
json={'message_id': 'msg-123', 'feedback_type': 1}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -274,7 +283,7 @@ class TestEmbedFeedbackEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||||
json={'message_id': 'msg-123', 'feedback_type': 2},
|
json={'message_id': 'msg-123', 'feedback_type': 2}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -285,7 +294,7 @@ class TestEmbedFeedbackEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||||
json={'message_id': 'msg-123', 'feedback_type': 99},
|
json={'message_id': 'msg-123', 'feedback_type': 99}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ def mock_circular_import_chain():
|
|||||||
clear=clear,
|
clear=clear,
|
||||||
):
|
):
|
||||||
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -58,12 +57,10 @@ def fake_knowledge_app():
|
|||||||
"""Create FakeApp with knowledge services (module scope for reuse)."""
|
"""Create FakeApp with knowledge services (module scope for reuse)."""
|
||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auth services
|
# Auth services
|
||||||
app.user_service = Mock()
|
app.user_service = Mock()
|
||||||
@@ -75,35 +72,33 @@ def fake_knowledge_app():
|
|||||||
|
|
||||||
# Knowledge service
|
# Knowledge service
|
||||||
app.knowledge_service = Mock()
|
app.knowledge_service = Mock()
|
||||||
app.knowledge_service.get_knowledge_bases = AsyncMock(
|
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
|
||||||
return_value=[
|
{
|
||||||
{
|
|
||||||
'uuid': 'test-kb-uuid',
|
|
||||||
'name': 'Test Knowledge Base',
|
|
||||||
'description': 'Test KB description',
|
|
||||||
'engine_plugin_id': 'test/engine',
|
|
||||||
'created_at': '2024-01-01T00:00:00',
|
|
||||||
'updated_at': '2024-01-01T00:00:00',
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
app.knowledge_service.get_knowledge_base = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
'uuid': 'test-kb-uuid',
|
'uuid': 'test-kb-uuid',
|
||||||
'name': 'Test Knowledge Base',
|
'name': 'Test Knowledge Base',
|
||||||
'description': 'Test KB description',
|
'description': 'Test KB description',
|
||||||
'engine_plugin_id': 'test/engine',
|
'engine_plugin_id': 'test/engine',
|
||||||
|
'created_at': '2024-01-01T00:00:00',
|
||||||
|
'updated_at': '2024-01-01T00:00:00',
|
||||||
}
|
}
|
||||||
)
|
])
|
||||||
|
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
|
||||||
|
'uuid': 'test-kb-uuid',
|
||||||
|
'name': 'Test Knowledge Base',
|
||||||
|
'description': 'Test KB description',
|
||||||
|
'engine_plugin_id': 'test/engine',
|
||||||
|
})
|
||||||
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
|
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
|
||||||
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
|
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
|
||||||
app.knowledge_service.delete_knowledge_base = AsyncMock()
|
app.knowledge_service.delete_knowledge_base = AsyncMock()
|
||||||
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(
|
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
|
||||||
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}]
|
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
|
||||||
)
|
])
|
||||||
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
|
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
|
||||||
app.knowledge_service.delete_file = AsyncMock()
|
app.knowledge_service.delete_file = AsyncMock()
|
||||||
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}])
|
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
|
||||||
|
{'content': 'test result', 'score': 0.95}
|
||||||
|
])
|
||||||
|
|
||||||
# RAG manager
|
# RAG manager
|
||||||
app.rag_mgr = Mock()
|
app.rag_mgr = Mock()
|
||||||
@@ -129,7 +124,8 @@ class TestKnowledgeBaseEndpoints:
|
|||||||
async def test_get_knowledge_bases_success(self, quart_test_client):
|
async def test_get_knowledge_bases_success(self, quart_test_client):
|
||||||
"""GET /api/v1/knowledge/bases returns knowledge base list."""
|
"""GET /api/v1/knowledge/bases returns knowledge base list."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/knowledge/bases',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -144,7 +140,7 @@ class TestKnowledgeBaseEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/knowledge/bases',
|
'/api/v1/knowledge/bases',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'},
|
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -156,7 +152,8 @@ class TestKnowledgeBaseEndpoints:
|
|||||||
async def test_get_single_knowledge_base_success(self, quart_test_client):
|
async def test_get_single_knowledge_base_success(self, quart_test_client):
|
||||||
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
|
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -170,7 +167,7 @@ class TestKnowledgeBaseEndpoints:
|
|||||||
response = await quart_test_client.put(
|
response = await quart_test_client.put(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'Updated KB'},
|
json={'name': 'Updated KB'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -181,7 +178,8 @@ class TestKnowledgeBaseEndpoints:
|
|||||||
async def test_delete_knowledge_base_success(self, quart_test_client):
|
async def test_delete_knowledge_base_success(self, quart_test_client):
|
||||||
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
|
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
|
||||||
response = await quart_test_client.delete(
|
response = await quart_test_client.delete(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -195,7 +193,8 @@ class TestKnowledgeBaseFilesEndpoints:
|
|||||||
async def test_get_files_success(self, quart_test_client):
|
async def test_get_files_success(self, quart_test_client):
|
||||||
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
|
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -209,7 +208,7 @@ class TestKnowledgeBaseFilesEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'},
|
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -221,7 +220,8 @@ class TestKnowledgeBaseFilesEndpoints:
|
|||||||
async def test_delete_file_from_knowledge_base(self, quart_test_client):
|
async def test_delete_file_from_knowledge_base(self, quart_test_client):
|
||||||
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
|
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
|
||||||
response = await quart_test_client.delete(
|
response = await quart_test_client.delete(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}},
|
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -249,7 +249,9 @@ class TestKnowledgeBaseRetrieveEndpoint:
|
|||||||
async def test_retrieve_without_query_returns_error(self, quart_test_client):
|
async def test_retrieve_without_query_returns_error(self, quart_test_client):
|
||||||
"""POST retrieve without query returns 400."""
|
"""POST retrieve without query returns 400."""
|
||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={}
|
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||||
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
|
json={}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ def mock_circular_import_chain():
|
|||||||
clear=clear,
|
clear=clear,
|
||||||
):
|
):
|
||||||
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -55,12 +54,10 @@ def fake_monitoring_app():
|
|||||||
"""Create FakeApp with monitoring services (module scope)."""
|
"""Create FakeApp with monitoring services (module scope)."""
|
||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
|
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
|
||||||
app.user_service = Mock()
|
app.user_service = Mock()
|
||||||
@@ -70,34 +67,40 @@ def fake_monitoring_app():
|
|||||||
|
|
||||||
# Monitoring service
|
# Monitoring service
|
||||||
app.monitoring_service = Mock()
|
app.monitoring_service = Mock()
|
||||||
app.monitoring_service.get_overview_metrics = AsyncMock(
|
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
|
||||||
return_value={
|
'total_messages': 100,
|
||||||
'total_messages': 100,
|
'total_llm_calls': 50,
|
||||||
'total_llm_calls': 50,
|
'total_sessions': 20,
|
||||||
'total_sessions': 20,
|
'active_sessions': 5,
|
||||||
'active_sessions': 5,
|
'total_errors': 2,
|
||||||
'total_errors': 2,
|
})
|
||||||
}
|
app.monitoring_service.get_messages = AsyncMock(return_value=(
|
||||||
)
|
[{'id': 'msg-1', 'content': 'test'}], 100
|
||||||
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100))
|
))
|
||||||
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50))
|
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
|
||||||
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10))
|
[{'id': 'llm-1'}], 50
|
||||||
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20))
|
))
|
||||||
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2))
|
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
|
||||||
app.monitoring_service.get_session_analysis = AsyncMock(
|
[{'id': 'emb-1'}], 10
|
||||||
return_value={
|
))
|
||||||
'found': True,
|
app.monitoring_service.get_sessions = AsyncMock(return_value=(
|
||||||
'session_id': 'sess-1',
|
[{'session_id': 'sess-1'}], 20
|
||||||
}
|
))
|
||||||
)
|
app.monitoring_service.get_errors = AsyncMock(return_value=(
|
||||||
app.monitoring_service.get_message_details = AsyncMock(
|
[{'id': 'err-1'}], 2
|
||||||
return_value={
|
))
|
||||||
'found': True,
|
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
|
||||||
'message_id': 'msg-1',
|
'found': True,
|
||||||
}
|
'session_id': 'sess-1',
|
||||||
)
|
})
|
||||||
|
app.monitoring_service.get_message_details = AsyncMock(return_value={
|
||||||
|
'found': True,
|
||||||
|
'message_id': 'msg-1',
|
||||||
|
})
|
||||||
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
|
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
|
||||||
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12))
|
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
|
||||||
|
[{'feedback_id': 'fb-1'}], 12
|
||||||
|
))
|
||||||
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
|
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
|
||||||
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
|
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
|
||||||
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
|
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
|
||||||
@@ -127,7 +130,8 @@ class TestMonitoringOverviewEndpoint:
|
|||||||
async def test_get_overview_success(self, quart_test_client):
|
async def test_get_overview_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/overview returns metrics."""
|
"""GET /api/v1/monitoring/overview returns metrics."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/overview',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -143,7 +147,8 @@ class TestMonitoringMessagesEndpoint:
|
|||||||
async def test_get_messages_success(self, quart_test_client):
|
async def test_get_messages_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/messages returns message list."""
|
"""GET /api/v1/monitoring/messages returns message list."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/messages',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -160,7 +165,8 @@ class TestMonitoringLLMCallsEndpoint:
|
|||||||
async def test_get_llm_calls_success(self, quart_test_client):
|
async def test_get_llm_calls_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/llm-calls."""
|
"""GET /api/v1/monitoring/llm-calls."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/llm-calls',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -174,7 +180,8 @@ class TestMonitoringEmbeddingCallsEndpoint:
|
|||||||
async def test_get_embedding_calls_success(self, quart_test_client):
|
async def test_get_embedding_calls_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/embedding-calls."""
|
"""GET /api/v1/monitoring/embedding-calls."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/embedding-calls',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -188,7 +195,8 @@ class TestMonitoringSessionsEndpoint:
|
|||||||
async def test_get_sessions_success(self, quart_test_client):
|
async def test_get_sessions_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/sessions."""
|
"""GET /api/v1/monitoring/sessions."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/sessions',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -202,7 +210,8 @@ class TestMonitoringErrorsEndpoint:
|
|||||||
async def test_get_errors_success(self, quart_test_client):
|
async def test_get_errors_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/errors."""
|
"""GET /api/v1/monitoring/errors."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/errors',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -216,7 +225,8 @@ class TestMonitoringAllDataEndpoint:
|
|||||||
async def test_get_all_data_success(self, quart_test_client):
|
async def test_get_all_data_success(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/data returns all data."""
|
"""GET /api/v1/monitoring/data returns all data."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/data',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -232,7 +242,8 @@ class TestMonitoringDetailsEndpoints:
|
|||||||
async def test_get_session_analysis(self, quart_test_client):
|
async def test_get_session_analysis(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
|
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/sessions/sess-1/analysis',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -241,7 +252,8 @@ class TestMonitoringDetailsEndpoints:
|
|||||||
async def test_get_message_details(self, quart_test_client):
|
async def test_get_message_details(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/messages/{id}/details."""
|
"""GET /api/v1/monitoring/messages/{id}/details."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/messages/msg-1/details',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -255,7 +267,8 @@ class TestMonitoringFeedbackEndpoints:
|
|||||||
async def test_get_feedback_stats(self, quart_test_client):
|
async def test_get_feedback_stats(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/feedback/stats."""
|
"""GET /api/v1/monitoring/feedback/stats."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/feedback/stats',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -264,7 +277,8 @@ class TestMonitoringFeedbackEndpoints:
|
|||||||
async def test_get_feedback_list(self, quart_test_client):
|
async def test_get_feedback_list(self, quart_test_client):
|
||||||
"""GET /api/v1/monitoring/feedback."""
|
"""GET /api/v1/monitoring/feedback."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/feedback',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -278,7 +292,8 @@ class TestMonitoringExportEndpoint:
|
|||||||
async def test_export_messages(self, quart_test_client):
|
async def test_export_messages(self, quart_test_client):
|
||||||
"""GET export?type=messages returns CSV."""
|
"""GET export?type=messages returns CSV."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/export?type=messages',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -288,7 +303,8 @@ class TestMonitoringExportEndpoint:
|
|||||||
async def test_export_llm_calls(self, quart_test_client):
|
async def test_export_llm_calls(self, quart_test_client):
|
||||||
"""GET export?type=llm-calls returns CSV."""
|
"""GET export?type=llm-calls returns CSV."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/export?type=llm-calls',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -297,7 +313,8 @@ class TestMonitoringExportEndpoint:
|
|||||||
async def test_export_sessions(self, quart_test_client):
|
async def test_export_sessions(self, quart_test_client):
|
||||||
"""GET export?type=sessions returns CSV."""
|
"""GET export?type=sessions returns CSV."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/export?type=sessions',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -306,7 +323,8 @@ class TestMonitoringExportEndpoint:
|
|||||||
async def test_export_feedback(self, quart_test_client):
|
async def test_export_feedback(self, quart_test_client):
|
||||||
"""GET export?type=feedback returns CSV."""
|
"""GET export?type=feedback returns CSV."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/monitoring/export?type=feedback',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ def mock_circular_import_chain():
|
|||||||
):
|
):
|
||||||
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
|
||||||
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
|
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -58,12 +57,10 @@ def fake_provider_app():
|
|||||||
"""Create FakeApp with provider/model services (module scope for reuse)."""
|
"""Create FakeApp with provider/model services (module scope for reuse)."""
|
||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auth services
|
# Auth services
|
||||||
app.user_service = Mock()
|
app.user_service = Mock()
|
||||||
@@ -75,23 +72,27 @@ def fake_provider_app():
|
|||||||
|
|
||||||
# Provider service
|
# Provider service
|
||||||
app.provider_service = Mock()
|
app.provider_service = Mock()
|
||||||
app.provider_service.get_providers = AsyncMock(
|
app.provider_service.get_providers = AsyncMock(return_value=[
|
||||||
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}]
|
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
|
||||||
)
|
])
|
||||||
app.provider_service.get_provider = AsyncMock(
|
app.provider_service.get_provider = AsyncMock(return_value={
|
||||||
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
|
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
|
||||||
)
|
})
|
||||||
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
|
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
|
||||||
app.provider_service.update_provider = AsyncMock(return_value={})
|
app.provider_service.update_provider = AsyncMock(return_value={})
|
||||||
app.provider_service.delete_provider = AsyncMock()
|
app.provider_service.delete_provider = AsyncMock()
|
||||||
app.provider_service.get_provider_model_counts = AsyncMock(
|
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
|
||||||
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0}
|
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
|
||||||
)
|
})
|
||||||
|
|
||||||
# LLM model service
|
# LLM model service
|
||||||
app.llm_model_service = Mock()
|
app.llm_model_service = Mock()
|
||||||
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}])
|
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
|
||||||
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'})
|
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
|
||||||
|
])
|
||||||
|
app.llm_model_service.get_llm_model = AsyncMock(return_value={
|
||||||
|
'uuid': 'test-model-uuid', 'name': 'gpt-4'
|
||||||
|
})
|
||||||
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
|
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
|
||||||
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
|
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
|
||||||
app.llm_model_service.delete_llm_model = AsyncMock()
|
app.llm_model_service.delete_llm_model = AsyncMock()
|
||||||
@@ -132,7 +133,8 @@ class TestProviderEndpoints:
|
|||||||
async def test_get_providers_success(self, quart_test_client):
|
async def test_get_providers_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/providers returns provider list with complete structure."""
|
"""GET /api/v1/provider/providers returns provider list with complete structure."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/providers',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -155,7 +157,8 @@ class TestProviderEndpoints:
|
|||||||
async def test_get_single_provider_success(self, quart_test_client):
|
async def test_get_single_provider_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
|
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/providers/test-provider-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -174,7 +177,7 @@ class TestProviderEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/provider/providers',
|
'/api/v1/provider/providers',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New Provider', 'requester': 'chatcmpl'},
|
json={'name': 'New Provider', 'requester': 'chatcmpl'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -191,7 +194,7 @@ class TestProviderEndpoints:
|
|||||||
response = await quart_test_client.put(
|
response = await quart_test_client.put(
|
||||||
'/api/v1/provider/providers/test-provider-uuid',
|
'/api/v1/provider/providers/test-provider-uuid',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'Updated Provider'},
|
json={'name': 'Updated Provider'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -202,7 +205,8 @@ class TestProviderEndpoints:
|
|||||||
async def test_delete_provider_success(self, quart_test_client):
|
async def test_delete_provider_success(self, quart_test_client):
|
||||||
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
|
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
|
||||||
response = await quart_test_client.delete(
|
response = await quart_test_client.delete(
|
||||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/providers/test-provider-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -211,7 +215,8 @@ class TestProviderEndpoints:
|
|||||||
async def test_get_provider_includes_model_counts(self, quart_test_client):
|
async def test_get_provider_includes_model_counts(self, quart_test_client):
|
||||||
"""GET provider response includes model counts."""
|
"""GET provider response includes model counts."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/providers/test-provider-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -232,7 +237,8 @@ class TestModelEndpoints:
|
|||||||
async def test_get_llm_models_success(self, quart_test_client):
|
async def test_get_llm_models_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/models/llm returns model list."""
|
"""GET /api/v1/provider/models/llm returns model list."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/models/llm',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -244,7 +250,8 @@ class TestModelEndpoints:
|
|||||||
async def test_get_single_llm_model_success(self, quart_test_client):
|
async def test_get_single_llm_model_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
|
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/models/llm/test-model-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -257,7 +264,7 @@ class TestModelEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/provider/models/llm',
|
'/api/v1/provider/models/llm',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'},
|
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -269,7 +276,8 @@ class TestModelEndpoints:
|
|||||||
async def test_delete_llm_model_success(self, quart_test_client):
|
async def test_delete_llm_model_success(self, quart_test_client):
|
||||||
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
|
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
|
||||||
response = await quart_test_client.delete(
|
response = await quart_test_client.delete(
|
||||||
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/models/llm/test-model-uuid',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -283,7 +291,8 @@ class TestEmbeddingModelEndpoints:
|
|||||||
async def test_get_embedding_models_success(self, quart_test_client):
|
async def test_get_embedding_models_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/models/embedding returns model list."""
|
"""GET /api/v1/provider/models/embedding returns model list."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/models/embedding',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -297,7 +306,7 @@ class TestEmbeddingModelEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/provider/models/embedding',
|
'/api/v1/provider/models/embedding',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'},
|
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -314,7 +323,8 @@ class TestRerankModelEndpoints:
|
|||||||
async def test_get_rerank_models_success(self, quart_test_client):
|
async def test_get_rerank_models_success(self, quart_test_client):
|
||||||
"""GET /api/v1/provider/models/rerank returns model list."""
|
"""GET /api/v1/provider/models/rerank returns model list."""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}
|
'/api/v1/provider/models/rerank',
|
||||||
|
headers={'Authorization': 'Bearer test_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -328,7 +338,7 @@ class TestRerankModelEndpoints:
|
|||||||
response = await quart_test_client.post(
|
response = await quart_test_client.post(
|
||||||
'/api/v1/provider/models/rerank',
|
'/api/v1/provider/models/rerank',
|
||||||
headers={'Authorization': 'Bearer test_token'},
|
headers={'Authorization': 'Bearer test_token'},
|
||||||
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'},
|
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ pytestmark = pytest.mark.integration
|
|||||||
|
|
||||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def mock_circular_import_chain():
|
def mock_circular_import_chain():
|
||||||
"""
|
"""
|
||||||
@@ -70,7 +69,6 @@ def mock_circular_import_chain():
|
|||||||
|
|
||||||
# ============== FAKE APPLICATION FOR API TESTS ==============
|
# ============== FAKE APPLICATION FOR API TESTS ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def fake_api_app():
|
def fake_api_app():
|
||||||
"""
|
"""
|
||||||
@@ -81,14 +79,12 @@ def fake_api_app():
|
|||||||
app = FakeApp()
|
app = FakeApp()
|
||||||
|
|
||||||
# API-specific config
|
# API-specific config
|
||||||
app.instance_config.data.update(
|
app.instance_config.data.update({
|
||||||
{
|
'api': {'port': 5300},
|
||||||
'api': {'port': 5300},
|
'plugin': {'enable_marketplace': True},
|
||||||
'plugin': {'enable_marketplace': True},
|
'space': {'url': 'https://space.langbot.app'},
|
||||||
'space': {'url': 'https://space.langbot.app'},
|
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# API-specific services
|
# API-specific services
|
||||||
app.user_service = Mock()
|
app.user_service = Mock()
|
||||||
@@ -122,7 +118,6 @@ def fake_api_app():
|
|||||||
|
|
||||||
# ============== QUART TEST CLIENT FIXTURE ==============
|
# ============== QUART TEST CLIENT FIXTURE ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def quart_test_client(fake_api_app, http_controller_cls):
|
async def quart_test_client(fake_api_app, http_controller_cls):
|
||||||
"""
|
"""
|
||||||
@@ -140,7 +135,6 @@ async def quart_test_client(fake_api_app, http_controller_cls):
|
|||||||
|
|
||||||
# ============== API SMOKE TESTS ==============
|
# ============== API SMOKE TESTS ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||||
class TestHealthEndpoint:
|
class TestHealthEndpoint:
|
||||||
"""Tests for /healthz endpoint - simplest smoke test."""
|
"""Tests for /healthz endpoint - simplest smoke test."""
|
||||||
@@ -228,7 +222,8 @@ class TestProtectedEndpoints:
|
|||||||
Protected endpoint returns 401 with invalid token.
|
Protected endpoint returns 401 with invalid token.
|
||||||
"""
|
"""
|
||||||
response = await quart_test_client.get(
|
response = await quart_test_client.get(
|
||||||
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'}
|
'/api/v1/user/check-token',
|
||||||
|
headers={'Authorization': 'Bearer invalid_token'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -259,7 +254,10 @@ class TestInvalidPayload:
|
|||||||
"""
|
"""
|
||||||
POST with wrong JSON structure returns stable error.
|
POST with wrong JSON structure returns stable error.
|
||||||
"""
|
"""
|
||||||
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'})
|
response = await quart_test_client.post(
|
||||||
|
'/api/v1/user/auth',
|
||||||
|
json={'wrong_field': 'value'}
|
||||||
|
)
|
||||||
|
|
||||||
# Should return error with stable JSON structure
|
# Should return error with stable JSON structure
|
||||||
assert response.status_code in (400, 500, 401)
|
assert response.status_code in (400, 500, 401)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
Persistence integration tests package.
|
Persistence integration tests package.
|
||||||
|
|
||||||
Tests for database migrations and storage behavior.
|
Tests for database migrations and storage behavior.
|
||||||
"""
|
"""
|
||||||
@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sqlite_db_url(tmp_path):
|
def sqlite_db_url(tmp_path):
|
||||||
"""Create SQLite URL with temporary database file."""
|
"""Create SQLite URL with temporary database file."""
|
||||||
db_file = tmp_path / 'test_migrations.db'
|
db_file = tmp_path / "test_migrations.db"
|
||||||
return f'sqlite+aiosqlite:///{db_file}'
|
return f"sqlite+aiosqlite:///{db_file}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
|
|||||||
|
|
||||||
# Verify revision
|
# Verify revision
|
||||||
rev = await get_alembic_current(sqlite_engine)
|
rev = await get_alembic_current(sqlite_engine)
|
||||||
assert rev is not None, 'Expected a revision after upgrade'
|
assert rev is not None, "Expected a revision after upgrade"
|
||||||
# Head should be the latest migration
|
# Head should be the latest migration
|
||||||
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
|
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_upgrade_idempotent(self, sqlite_engine):
|
async def test_upgrade_idempotent(self, sqlite_engine):
|
||||||
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
|
|||||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||||
|
|
||||||
rev2 = await get_alembic_current(sqlite_engine)
|
rev2 = await get_alembic_current(sqlite_engine)
|
||||||
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
|
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||||
|
|
||||||
|
|
||||||
class TestSQLiteMigrationFreshDatabase:
|
class TestSQLiteMigrationFreshDatabase:
|
||||||
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
|
|||||||
4. Verify revision
|
4. Verify revision
|
||||||
"""
|
"""
|
||||||
# Use different DB file for fresh test
|
# Use different DB file for fresh test
|
||||||
fresh_db_file = tmp_path / 'test_migrations_fresh.db'
|
fresh_db_file = tmp_path / "test_migrations_fresh.db"
|
||||||
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
|
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||||
fresh_engine = create_async_engine(fresh_url)
|
fresh_engine = create_async_engine(fresh_url)
|
||||||
|
|
||||||
# Create tables on fresh DB
|
# Create tables on fresh DB
|
||||||
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
|
|||||||
|
|
||||||
# Verify revision
|
# Verify revision
|
||||||
rev = await get_alembic_current(fresh_engine)
|
rev = await get_alembic_current(fresh_engine)
|
||||||
assert rev is not None, 'Expected a revision on fresh DB'
|
assert rev is not None, "Expected a revision on fresh DB"
|
||||||
|
|
||||||
await fresh_engine.dispose()
|
await fresh_engine.dispose()
|
||||||
|
|
||||||
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
|
|||||||
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
|
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
|
||||||
any arbitrary failure with try-except pass.
|
any arbitrary failure with try-except pass.
|
||||||
"""
|
"""
|
||||||
fresh_db_file = tmp_path / 'test_empty_migrations.db'
|
fresh_db_file = tmp_path / "test_empty_migrations.db"
|
||||||
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
|
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||||
fresh_engine = create_async_engine(fresh_url)
|
fresh_engine = create_async_engine(fresh_url)
|
||||||
|
|
||||||
# Capture the actual behavior
|
# Capture the actual behavior
|
||||||
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
|
|||||||
# Verify specific behavior - one of two outcomes is expected
|
# Verify specific behavior - one of two outcomes is expected
|
||||||
if actual_result is not None:
|
if actual_result is not None:
|
||||||
# Migration succeeded - verify revision exists
|
# Migration succeeded - verify revision exists
|
||||||
assert actual_result is not None, 'Revision should exist after successful migration'
|
assert actual_result is not None, "Revision should exist after successful migration"
|
||||||
else:
|
else:
|
||||||
# Migration failed - verify the error type is known
|
# Migration failed - verify the error type is known
|
||||||
# Alembic typically raises specific errors for missing tables
|
# Alembic typically raises specific errors for missing tables
|
||||||
assert actual_error is not None, 'Error should be captured if migration failed'
|
assert actual_error is not None, "Error should be captured if migration failed"
|
||||||
# Log the error type for documentation (don't silently pass)
|
# Log the error type for documentation (don't silently pass)
|
||||||
error_type = type(actual_error).__name__
|
error_type = type(actual_error).__name__
|
||||||
# Acceptable error types for empty DB scenarios
|
# Acceptable error types for empty DB scenarios
|
||||||
acceptable_errors = [
|
acceptable_errors = [
|
||||||
'OperationalError', # SQLite table not found
|
'OperationalError', # SQLite table not found
|
||||||
'ProgrammingError', # SQLAlchemy errors
|
'ProgrammingError', # SQLAlchemy errors
|
||||||
'CommandError', # Alembic command errors
|
'CommandError', # Alembic command errors
|
||||||
]
|
]
|
||||||
assert error_type in acceptable_errors, (
|
assert error_type in acceptable_errors, (
|
||||||
f'Unexpected error type: {error_type}. '
|
f"Unexpected error type: {error_type}. "
|
||||||
f'This may indicate a regression in migration behavior. '
|
f"This may indicate a regression in migration behavior. "
|
||||||
f'Error: {actual_error}'
|
f"Error: {actual_error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
|
|||||||
|
|
||||||
# No stamp - should return None
|
# No stamp - should return None
|
||||||
rev = await get_alembic_current(sqlite_engine)
|
rev = await get_alembic_current(sqlite_engine)
|
||||||
assert rev is None, f'Expected None for unstamped DB, got {rev}'
|
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
|
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
|
||||||
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
|
|||||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||||
|
|
||||||
rev = await get_alembic_current(sqlite_engine)
|
rev = await get_alembic_current(sqlite_engine)
|
||||||
assert rev == '0001_baseline'
|
assert rev == '0001_baseline'
|
||||||
@@ -34,14 +34,14 @@ def postgres_url():
|
|||||||
"""Get PostgreSQL URL from environment."""
|
"""Get PostgreSQL URL from environment."""
|
||||||
url = os.environ.get('TEST_POSTGRES_URL')
|
url = os.environ.get('TEST_POSTGRES_URL')
|
||||||
if not url:
|
if not url:
|
||||||
pytest.skip('TEST_POSTGRES_URL not set')
|
pytest.skip("TEST_POSTGRES_URL not set")
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def postgres_engine(postgres_url):
|
async def postgres_engine(postgres_url):
|
||||||
"""Create async PostgreSQL engine."""
|
"""Create async PostgreSQL engine."""
|
||||||
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT')
|
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
|
||||||
yield engine
|
yield engine
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
|
|||||||
async with postgres_engine.begin() as conn:
|
async with postgres_engine.begin() as conn:
|
||||||
# Drop alembic_version table if exists
|
# Drop alembic_version table if exists
|
||||||
try:
|
try:
|
||||||
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
|
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
|
|||||||
|
|
||||||
async with postgres_engine.begin() as conn:
|
async with postgres_engine.begin() as conn:
|
||||||
try:
|
try:
|
||||||
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
|
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -83,7 +83,9 @@ class TestPostgreSQLMigrationBaseline:
|
|||||||
"""Tests for baseline stamp workflow on PostgreSQL."""
|
"""Tests for baseline stamp workflow on PostgreSQL."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version):
|
async def test_postgres_baseline_stamp_sets_revision(
|
||||||
|
self, postgres_engine, clean_tables, clean_alembic_version
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Stamp baseline on existing tables sets correct revision.
|
Stamp baseline on existing tables sets correct revision.
|
||||||
|
|
||||||
@@ -104,7 +106,9 @@ class TestPostgreSQLMigrationBaseline:
|
|||||||
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version):
|
async def test_postgres_baseline_stamp_on_empty_db(
|
||||||
|
self, postgres_engine, clean_tables, clean_alembic_version
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Stamp on empty database (no tables) still sets revision.
|
Stamp on empty database (no tables) still sets revision.
|
||||||
|
|
||||||
@@ -121,7 +125,9 @@ class TestPostgreSQLMigrationUpgrade:
|
|||||||
"""Tests for upgrade to head workflow on PostgreSQL."""
|
"""Tests for upgrade to head workflow on PostgreSQL."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version):
|
async def test_postgres_upgrade_from_baseline_to_head(
|
||||||
|
self, postgres_engine, clean_tables, clean_alembic_version
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Upgrade from baseline to head applies all migrations.
|
Upgrade from baseline to head applies all migrations.
|
||||||
|
|
||||||
@@ -143,12 +149,14 @@ class TestPostgreSQLMigrationUpgrade:
|
|||||||
|
|
||||||
# Verify revision
|
# Verify revision
|
||||||
rev = await get_alembic_current(postgres_engine)
|
rev = await get_alembic_current(postgres_engine)
|
||||||
assert rev is not None, 'Expected a revision after upgrade'
|
assert rev is not None, "Expected a revision after upgrade"
|
||||||
# Head should be the latest migration (0005 for current state)
|
# Head should be the latest migration (0005 for current state)
|
||||||
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
|
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version):
|
async def test_postgres_upgrade_idempotent(
|
||||||
|
self, postgres_engine, clean_tables, clean_alembic_version
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Running upgrade to head multiple times is idempotent.
|
Running upgrade to head multiple times is idempotent.
|
||||||
|
|
||||||
@@ -172,7 +180,7 @@ class TestPostgreSQLMigrationUpgrade:
|
|||||||
await run_alembic_upgrade(postgres_engine, 'head')
|
await run_alembic_upgrade(postgres_engine, 'head')
|
||||||
|
|
||||||
rev2 = await get_alembic_current(postgres_engine)
|
rev2 = await get_alembic_current(postgres_engine)
|
||||||
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
|
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||||
|
|
||||||
|
|
||||||
class TestPostgreSQLMigrationGetCurrent:
|
class TestPostgreSQLMigrationGetCurrent:
|
||||||
@@ -191,7 +199,7 @@ class TestPostgreSQLMigrationGetCurrent:
|
|||||||
|
|
||||||
# No stamp - should return None
|
# No stamp - should return None
|
||||||
rev = await get_alembic_current(postgres_engine)
|
rev = await get_alembic_current(postgres_engine)
|
||||||
assert rev is None, f'Expected None for unstamped DB, got {rev}'
|
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgres_get_current_after_stamp_returns_revision(
|
async def test_postgres_get_current_after_stamp_returns_revision(
|
||||||
@@ -206,4 +214,4 @@ class TestPostgreSQLMigrationGetCurrent:
|
|||||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||||
|
|
||||||
rev = await get_alembic_current(postgres_engine)
|
rev = await get_alembic_current(postgres_engine)
|
||||||
assert rev == '0001_baseline'
|
assert rev == '0001_baseline'
|
||||||
@@ -2,4 +2,4 @@
|
|||||||
Pipeline integration tests package.
|
Pipeline integration tests package.
|
||||||
|
|
||||||
Tests for full pipeline flow using fake provider/runner.
|
Tests for full pipeline flow using fake provider/runner.
|
||||||
"""
|
"""
|
||||||
@@ -26,7 +26,6 @@ pytestmark = pytest.mark.integration
|
|||||||
|
|
||||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def mock_circular_import_chain():
|
def mock_circular_import_chain():
|
||||||
"""
|
"""
|
||||||
@@ -104,7 +103,6 @@ def mock_circular_import_chain():
|
|||||||
|
|
||||||
# ============== FAKE RUNNER ==============
|
# ============== FAKE RUNNER ==============
|
||||||
|
|
||||||
|
|
||||||
class FakeRunner:
|
class FakeRunner:
|
||||||
"""Minimal fake runner class for pipeline integration tests.
|
"""Minimal fake runner class for pipeline integration tests.
|
||||||
|
|
||||||
@@ -119,13 +117,12 @@ class FakeRunner:
|
|||||||
self.config = config or {}
|
self.config = config or {}
|
||||||
self._provider = FakeProvider()
|
self._provider = FakeProvider()
|
||||||
# Instance-level configuration set via class attribute
|
# Instance-level configuration set via class attribute
|
||||||
self._response_text = 'fake response'
|
self._response_text = "fake response"
|
||||||
self._raise_error = None
|
self._raise_error = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def returns(cls, text: str):
|
def returns(cls, text: str):
|
||||||
"""Create a runner class configured to return specific text."""
|
"""Create a runner class configured to return specific text."""
|
||||||
|
|
||||||
# We create a subclass with configured response
|
# We create a subclass with configured response
|
||||||
class ConfiguredRunner(cls):
|
class ConfiguredRunner(cls):
|
||||||
name = cls.name
|
name = cls.name
|
||||||
@@ -135,13 +132,11 @@ class FakeRunner:
|
|||||||
def __init__(self, app=None, config=None):
|
def __init__(self, app=None, config=None):
|
||||||
super().__init__(app, config)
|
super().__init__(app, config)
|
||||||
self._response_text = text
|
self._response_text = text
|
||||||
|
|
||||||
return ConfiguredRunner
|
return ConfiguredRunner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def raises(cls, error: Exception):
|
def raises(cls, error: Exception):
|
||||||
"""Create a runner class configured to raise an error."""
|
"""Create a runner class configured to raise an error."""
|
||||||
|
|
||||||
class ConfiguredRunner(cls):
|
class ConfiguredRunner(cls):
|
||||||
name = cls.name
|
name = cls.name
|
||||||
_response_text = None
|
_response_text = None
|
||||||
@@ -150,7 +145,6 @@ class FakeRunner:
|
|||||||
def __init__(self, app=None, config=None):
|
def __init__(self, app=None, config=None):
|
||||||
super().__init__(app, config)
|
super().__init__(app, config)
|
||||||
self._raise_error = error
|
self._raise_error = error
|
||||||
|
|
||||||
return ConfiguredRunner
|
return ConfiguredRunner
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
@@ -167,7 +161,6 @@ class FakeRunner:
|
|||||||
|
|
||||||
# ============== PIPELINE APP FIXTURE ==============
|
# ============== PIPELINE APP FIXTURE ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pipeline_app():
|
def pipeline_app():
|
||||||
"""
|
"""
|
||||||
@@ -194,7 +187,6 @@ def pipeline_app():
|
|||||||
def __init__(self, name, messages):
|
def __init__(self, name, messages):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return MockPrompt(self.name, list(self.messages))
|
return MockPrompt(self.name, list(self.messages))
|
||||||
|
|
||||||
@@ -245,17 +237,14 @@ def fake_platform_adapter():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def set_fake_runner():
|
def set_fake_runner():
|
||||||
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
|
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
|
||||||
|
|
||||||
def _set_runner(runner_cls):
|
def _set_runner(runner_cls):
|
||||||
# preregistered_runners expects a list of runner classes
|
# preregistered_runners expects a list of runner classes
|
||||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
|
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
|
||||||
|
|
||||||
return _set_runner
|
return _set_runner
|
||||||
|
|
||||||
|
|
||||||
# ============== PIPELINE CONFIGURATION ==============
|
# ============== PIPELINE CONFIGURATION ==============
|
||||||
|
|
||||||
|
|
||||||
def create_minimal_pipeline_config():
|
def create_minimal_pipeline_config():
|
||||||
"""Create minimal pipeline configuration for tests."""
|
"""Create minimal pipeline configuration for tests."""
|
||||||
return {
|
return {
|
||||||
@@ -284,7 +273,6 @@ def create_minimal_pipeline_config():
|
|||||||
|
|
||||||
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
|
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
|
||||||
|
|
||||||
|
|
||||||
async def collect_processor_results(processor, query, stage_name):
|
async def collect_processor_results(processor, query, stage_name):
|
||||||
"""
|
"""
|
||||||
Helper to handle the coroutine -> async_generator pattern.
|
Helper to handle the coroutine -> async_generator pattern.
|
||||||
@@ -308,7 +296,6 @@ async def collect_processor_results(processor, query, stage_name):
|
|||||||
|
|
||||||
# ============== TESTS ==============
|
# ============== TESTS ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||||
class TestPipelineStageChainReal:
|
class TestPipelineStageChainReal:
|
||||||
"""Tests for real pipeline stage chain."""
|
"""Tests for real pipeline stage chain."""
|
||||||
@@ -350,7 +337,7 @@ class TestPreProcessorStage:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Create query with adapter
|
# Create query with adapter
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
|
|
||||||
@@ -378,7 +365,7 @@ class TestPreProcessorStage:
|
|||||||
|
|
||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
query = text_query('test message content')
|
query = text_query("test message content")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
|
|
||||||
@@ -409,11 +396,11 @@ class TestProcessorStage:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Set fake runner that returns pong
|
# Set fake runner that returns pong
|
||||||
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
|
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||||
set_fake_runner(fake_runner)
|
set_fake_runner(fake_runner)
|
||||||
|
|
||||||
# Create query
|
# Create query
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
query.resp_messages = []
|
query.resp_messages = []
|
||||||
@@ -427,7 +414,6 @@ class TestProcessorStage:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -446,7 +432,7 @@ class TestProcessorStage:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Create query
|
# Create query
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
|
|
||||||
@@ -459,7 +445,6 @@ class TestProcessorStage:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -477,13 +462,13 @@ class TestProcessorStage:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Create query
|
# Create query
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
query.resp_messages = []
|
query.resp_messages = []
|
||||||
|
|
||||||
# Create reply chain
|
# Create reply chain
|
||||||
reply_chain = text_chain('plugin response')
|
reply_chain = text_chain("plugin response")
|
||||||
|
|
||||||
# Mock plugin_connector to prevent default with reply
|
# Mock plugin_connector to prevent default with reply
|
||||||
mock_event_ctx = Mock()
|
mock_event_ctx = Mock()
|
||||||
@@ -494,7 +479,6 @@ class TestProcessorStage:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -518,7 +502,7 @@ class TestRunnerExceptionFlow:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Set fake runner that raises exception
|
# Set fake runner that raises exception
|
||||||
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded'))
|
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
|
||||||
set_fake_runner(fake_runner)
|
set_fake_runner(fake_runner)
|
||||||
|
|
||||||
# Create query with exception handling config
|
# Create query with exception handling config
|
||||||
@@ -526,7 +510,7 @@ class TestRunnerExceptionFlow:
|
|||||||
config['output']['misc']['exception-handling'] = 'show-hint'
|
config['output']['misc']['exception-handling'] = 'show-hint'
|
||||||
config['output']['misc']['failure-hint'] = 'Request failed.'
|
config['output']['misc']['failure-hint'] = 'Request failed.'
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = config
|
query.pipeline_config = config
|
||||||
|
|
||||||
@@ -539,7 +523,6 @@ class TestRunnerExceptionFlow:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -558,14 +541,14 @@ class TestRunnerExceptionFlow:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Set fake runner that raises specific exception
|
# Set fake runner that raises specific exception
|
||||||
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error'))
|
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
|
||||||
set_fake_runner(fake_runner)
|
set_fake_runner(fake_runner)
|
||||||
|
|
||||||
# Create query with show-error mode
|
# Create query with show-error mode
|
||||||
config = create_minimal_pipeline_config()
|
config = create_minimal_pipeline_config()
|
||||||
config['output']['misc']['exception-handling'] = 'show-error'
|
config['output']['misc']['exception-handling'] = 'show-error'
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = config
|
query.pipeline_config = config
|
||||||
|
|
||||||
@@ -578,7 +561,6 @@ class TestRunnerExceptionFlow:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -596,14 +578,14 @@ class TestRunnerExceptionFlow:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Set fake runner that raises exception
|
# Set fake runner that raises exception
|
||||||
fake_runner = FakeRunner().raises(Exception('Hidden error'))
|
fake_runner = FakeRunner().raises(Exception("Hidden error"))
|
||||||
set_fake_runner(fake_runner)
|
set_fake_runner(fake_runner)
|
||||||
|
|
||||||
# Create query with hide mode
|
# Create query with hide mode
|
||||||
config = create_minimal_pipeline_config()
|
config = create_minimal_pipeline_config()
|
||||||
config['output']['misc']['exception-handling'] = 'hide'
|
config['output']['misc']['exception-handling'] = 'hide'
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = config
|
query.pipeline_config = config
|
||||||
|
|
||||||
@@ -616,7 +598,6 @@ class TestRunnerExceptionFlow:
|
|||||||
|
|
||||||
# Create Processor stage
|
# Create Processor stage
|
||||||
from langbot.pkg.pipeline.process import process
|
from langbot.pkg.pipeline.process import process
|
||||||
|
|
||||||
processor_stage = process.Processor(pipeline_app)
|
processor_stage = process.Processor(pipeline_app)
|
||||||
await processor_stage.initialize(query.pipeline_config)
|
await processor_stage.initialize(query.pipeline_config)
|
||||||
|
|
||||||
@@ -642,7 +623,7 @@ class TestSendResponseBackStage:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Create query with response message
|
# Create query with response message
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
|
|
||||||
@@ -685,12 +666,12 @@ class TestStageChainIntegration:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Set fake runner
|
# Set fake runner
|
||||||
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
|
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||||
set_fake_runner(fake_runner)
|
set_fake_runner(fake_runner)
|
||||||
|
|
||||||
# Create query
|
# Create query
|
||||||
config = create_minimal_pipeline_config()
|
config = create_minimal_pipeline_config()
|
||||||
query = text_query('ping')
|
query = text_query("ping")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = config
|
query.pipeline_config = config
|
||||||
query.resp_messages = []
|
query.resp_messages = []
|
||||||
@@ -709,7 +690,7 @@ class TestStageChainIntegration:
|
|||||||
|
|
||||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -730,7 +711,6 @@ class TestStageChainIntegration:
|
|||||||
|
|
||||||
# Build resp_message_chain from resp_messages
|
# Build resp_message_chain from resp_messages
|
||||||
from tests.factories.message import text_chain
|
from tests.factories.message import text_chain
|
||||||
|
|
||||||
for resp_msg in query.resp_messages:
|
for resp_msg in query.resp_messages:
|
||||||
if resp_msg.content:
|
if resp_msg.content:
|
||||||
query.resp_message_chain.append(text_chain(resp_msg.content))
|
query.resp_message_chain.append(text_chain(resp_msg.content))
|
||||||
@@ -757,7 +737,7 @@ class TestStageChainIntegration:
|
|||||||
adapter, platform = fake_platform_adapter
|
adapter, platform = fake_platform_adapter
|
||||||
|
|
||||||
# Create query
|
# Create query
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.adapter = adapter
|
query.adapter = adapter
|
||||||
query.pipeline_config = create_minimal_pipeline_config()
|
query.pipeline_config = create_minimal_pipeline_config()
|
||||||
|
|
||||||
@@ -774,7 +754,7 @@ class TestStageChainIntegration:
|
|||||||
|
|
||||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -795,4 +775,4 @@ class TestStageChainIntegration:
|
|||||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||||
|
|
||||||
# Chain stops here - no resp_messages
|
# Chain stops here - no resp_messages
|
||||||
assert len(query.resp_messages) == 0
|
assert len(query.resp_messages) == 0
|
||||||
@@ -3,4 +3,4 @@ Smoke tests package.
|
|||||||
|
|
||||||
Smoke tests verify basic functionality works without testing edge cases.
|
Smoke tests verify basic functionality works without testing edge cases.
|
||||||
Run with: uv run pytest tests/smoke/ -q
|
Run with: uv run pytest tests/smoke/ -q
|
||||||
"""
|
"""
|
||||||
@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
|
|||||||
assert app.instance_config is not None
|
assert app.instance_config is not None
|
||||||
|
|
||||||
# Verify default config
|
# Verify default config
|
||||||
assert app.instance_config.data['command']['prefix'] == ['/', '!']
|
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
|
||||||
assert app.instance_config.data['command']['enable'] is True
|
assert app.instance_config.data["command"]["enable"] is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_provider_returns_text(self):
|
async def test_fake_provider_returns_text(self):
|
||||||
"""Test FakeProvider returns configured response."""
|
"""Test FakeProvider returns configured response."""
|
||||||
provider = FakeProvider(default_response='test response')
|
provider = FakeProvider(default_response="test response")
|
||||||
|
|
||||||
# Create mock model with provider
|
# Create mock model with provider
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
|
|
||||||
# Create a simple query
|
# Create a simple query
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
# Simulate invoke
|
# Simulate invoke
|
||||||
result = await provider.invoke_llm(
|
result = await provider.invoke_llm(
|
||||||
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.role == 'assistant'
|
assert result.role == "assistant"
|
||||||
assert result.content == 'test response'
|
assert result.content == "test response"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_provider_pong(self):
|
async def test_fake_provider_pong(self):
|
||||||
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
|
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
|
||||||
provider = fake_provider_pong()
|
provider = fake_provider_pong()
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
query = text_query('ping')
|
query = text_query("ping")
|
||||||
|
|
||||||
result = await provider.invoke_llm(
|
result = await provider.invoke_llm(
|
||||||
query=query,
|
query=query,
|
||||||
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_provider_streaming(self):
|
async def test_fake_provider_streaming(self):
|
||||||
"""Test FakeProvider streaming response."""
|
"""Test FakeProvider streaming response."""
|
||||||
provider = FakeProvider().returns_streaming(['Hello', ' World'])
|
provider = FakeProvider().returns_streaming(["Hello", " World"])
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
# invoke_llm_stream returns an async generator, don't await it
|
# invoke_llm_stream returns an async generator, don't await it
|
||||||
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
|
|||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
assert len(chunks) == 2
|
assert len(chunks) == 2
|
||||||
assert chunks[0].content == 'Hello'
|
assert chunks[0].content == "Hello"
|
||||||
assert chunks[1].content == ' World'
|
assert chunks[1].content == " World"
|
||||||
assert chunks[1].is_final is True
|
assert chunks[1].is_final is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
|
|||||||
"""Test FakeProvider simulates timeout error."""
|
"""Test FakeProvider simulates timeout error."""
|
||||||
provider = FakeProvider().timeout()
|
provider = FakeProvider().timeout()
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
with pytest.raises(TimeoutError, match='Provider timeout'):
|
with pytest.raises(TimeoutError, match="Provider timeout"):
|
||||||
await provider.invoke_llm(
|
await provider.invoke_llm(
|
||||||
query=query,
|
query=query,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
|
|||||||
"""Test FakeProvider simulates rate limit error."""
|
"""Test FakeProvider simulates rate limit error."""
|
||||||
provider = FakeProvider().rate_limit()
|
provider = FakeProvider().rate_limit()
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
with pytest.raises(Exception, match='Rate limit exceeded'):
|
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||||
await provider.invoke_llm(
|
await provider.invoke_llm(
|
||||||
query=query,
|
query=query,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
|
|||||||
async def test_fake_provider_captures_requests(self):
|
async def test_fake_provider_captures_requests(self):
|
||||||
"""Test FakeProvider captures request arguments."""
|
"""Test FakeProvider captures request arguments."""
|
||||||
provider = FakeProvider()
|
provider = FakeProvider()
|
||||||
model = fake_model(name='gpt-4', provider=provider)
|
model = fake_model(name="gpt-4", provider=provider)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
await provider.invoke_llm(
|
await provider.invoke_llm(
|
||||||
query=query,
|
query=query,
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{'role': 'user', 'content': 'hello'}],
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
funcs=[{'name': 'test_func'}],
|
funcs=[{"name": "test_func"}],
|
||||||
extra_args={'temperature': 0.7},
|
extra_args={"temperature": 0.7},
|
||||||
)
|
)
|
||||||
|
|
||||||
captured = provider.get_captured_requests()
|
captured = provider.get_captured_requests()
|
||||||
assert len(captured) == 1
|
assert len(captured) == 1
|
||||||
assert captured[0]['model'] == 'gpt-4'
|
assert captured[0]["model"] == "gpt-4"
|
||||||
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}]
|
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
|
||||||
assert captured[0]['funcs'] == [{'name': 'test_func'}]
|
assert captured[0]["funcs"] == [{"name": "test_func"}]
|
||||||
assert captured[0]['extra_args'] == {'temperature': 0.7}
|
assert captured[0]["extra_args"] == {"temperature": 0.7}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_platform_capture_outbound(self):
|
async def test_fake_platform_capture_outbound(self):
|
||||||
"""Test FakePlatform captures outbound messages."""
|
"""Test FakePlatform captures outbound messages."""
|
||||||
platform = FakePlatform(bot_account_id='test-bot')
|
platform = FakePlatform(bot_account_id="test-bot")
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
# Simulate sending reply
|
# Simulate sending reply
|
||||||
from tests.factories.message import text_chain
|
from tests.factories.message import text_chain
|
||||||
|
|
||||||
reply_chain = text_chain('response text')
|
reply_chain = text_chain("response text")
|
||||||
event = query.message_event
|
event = query.message_event
|
||||||
|
|
||||||
await platform.reply_message(event, reply_chain, quote_origin=False)
|
await platform.reply_message(event, reply_chain, quote_origin=False)
|
||||||
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
|
|||||||
# Verify captured
|
# Verify captured
|
||||||
outbound = platform.get_outbound_messages()
|
outbound = platform.get_outbound_messages()
|
||||||
assert len(outbound) == 1
|
assert len(outbound) == 1
|
||||||
assert outbound[0]['type'] == 'reply'
|
assert outbound[0]["type"] == "reply"
|
||||||
assert outbound[0]['message'] == reply_chain
|
assert outbound[0]["message"] == reply_chain
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_platform_friend_message(self):
|
async def test_fake_platform_friend_message(self):
|
||||||
"""Test FakePlatform creates friend message events."""
|
"""Test FakePlatform creates friend message events."""
|
||||||
platform = FakePlatform(bot_account_id='test-bot')
|
platform = FakePlatform(bot_account_id="test-bot")
|
||||||
|
|
||||||
event = platform.create_friend_message(
|
event = platform.create_friend_message(
|
||||||
text='hello bot',
|
text="hello bot",
|
||||||
sender_id=12345,
|
sender_id=12345,
|
||||||
nickname='TestUser',
|
nickname="TestUser",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert event.type == 'FriendMessage'
|
assert event.type == "FriendMessage"
|
||||||
assert event.sender.id == 12345
|
assert event.sender.id == 12345
|
||||||
assert event.sender.nickname == 'TestUser'
|
assert event.sender.nickname == "TestUser"
|
||||||
assert str(event.message_chain) == 'hello bot'
|
assert str(event.message_chain) == "hello bot"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_platform_group_message_with_mention(self):
|
async def test_fake_platform_group_message_with_mention(self):
|
||||||
"""Test FakePlatform creates group message with @mention."""
|
"""Test FakePlatform creates group message with @mention."""
|
||||||
platform = FakePlatform(bot_account_id='test-bot')
|
platform = FakePlatform(bot_account_id="test-bot")
|
||||||
|
|
||||||
event = platform.create_group_message(
|
event = platform.create_group_message(
|
||||||
text='hello everyone',
|
text="hello everyone",
|
||||||
sender_id=12345,
|
sender_id=12345,
|
||||||
group_id=99999,
|
group_id=99999,
|
||||||
mention_bot=True,
|
mention_bot=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert event.type == 'GroupMessage'
|
assert event.type == "GroupMessage"
|
||||||
assert event.sender.id == 12345
|
assert event.sender.id == 12345
|
||||||
assert event.group.id == 99999
|
assert event.group.id == 99999
|
||||||
|
|
||||||
@@ -220,57 +220,54 @@ class TestFakeMessageFlow:
|
|||||||
async def test_query_factories_basic(self):
|
async def test_query_factories_basic(self):
|
||||||
"""Test basic query factory functions."""
|
"""Test basic query factory functions."""
|
||||||
# Text query
|
# Text query
|
||||||
q1 = text_query('hello world')
|
q1 = text_query("hello world")
|
||||||
assert q1.launcher_type.value == 'person'
|
assert q1.launcher_type.value == "person"
|
||||||
assert str(q1.message_chain) == 'hello world'
|
assert str(q1.message_chain) == "hello world"
|
||||||
|
|
||||||
# Group query
|
# Group query
|
||||||
from tests.factories import group_text_query
|
from tests.factories import group_text_query
|
||||||
|
q2 = group_text_query("hello group", group_id=88888)
|
||||||
q2 = group_text_query('hello group', group_id=88888)
|
assert q2.launcher_type.value == "group"
|
||||||
assert q2.launcher_type.value == 'group'
|
|
||||||
assert q2.launcher_id == 88888
|
assert q2.launcher_id == 88888
|
||||||
|
|
||||||
# Command query
|
# Command query
|
||||||
from tests.factories import command_query
|
from tests.factories import command_query
|
||||||
|
q3 = command_query("help", prefix="/")
|
||||||
q3 = command_query('help', prefix='/')
|
assert str(q3.message_chain) == "/help"
|
||||||
assert str(q3.message_chain) == '/help'
|
|
||||||
|
|
||||||
# Mention query
|
# Mention query
|
||||||
from tests.factories import mention_query
|
from tests.factories import mention_query
|
||||||
|
q4 = mention_query("hi", target="test-bot", group_id=77777)
|
||||||
q4 = mention_query('hi', target='test-bot', group_id=77777)
|
assert q4.launcher_type.value == "group"
|
||||||
assert q4.launcher_type.value == 'group'
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_platform_send_failure(self):
|
async def test_fake_platform_send_failure(self):
|
||||||
"""Test FakePlatform simulates send failure."""
|
"""Test FakePlatform simulates send failure."""
|
||||||
platform = FakePlatform().send_failure()
|
platform = FakePlatform().send_failure()
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
from tests.factories.message import text_chain
|
from tests.factories.message import text_chain
|
||||||
|
|
||||||
with pytest.raises(Exception, match='Platform send failure'):
|
with pytest.raises(Exception, match="Platform send failure"):
|
||||||
await platform.reply_message(
|
await platform.reply_message(
|
||||||
query.message_event,
|
query.message_event,
|
||||||
text_chain('response'),
|
text_chain("response"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mock_platform_adapter(self):
|
async def test_mock_platform_adapter(self):
|
||||||
"""Test mock_platform_adapter helper."""
|
"""Test mock_platform_adapter helper."""
|
||||||
platform = FakePlatform(bot_account_id='bot-123')
|
platform = FakePlatform(bot_account_id="bot-123")
|
||||||
adapter = mock_platform_adapter(platform)
|
adapter = mock_platform_adapter(platform)
|
||||||
|
|
||||||
assert adapter.bot_account_id == 'bot-123'
|
assert adapter.bot_account_id == "bot-123"
|
||||||
assert adapter._fake_platform is platform
|
assert adapter._fake_platform is platform
|
||||||
|
|
||||||
# Test reply_message is wired
|
# Test reply_message is wired
|
||||||
from tests.factories.message import text_chain
|
from tests.factories.message import text_chain
|
||||||
|
|
||||||
query = text_query('test')
|
query = text_query("test")
|
||||||
await adapter.reply_message(query.message_event, text_chain('response'))
|
await adapter.reply_message(query.message_event, text_chain("response"))
|
||||||
|
|
||||||
# Verify platform captured it
|
# Verify platform captured it
|
||||||
assert len(platform.get_outbound_messages()) == 1
|
assert len(platform.get_outbound_messages()) == 1
|
||||||
@@ -296,18 +293,18 @@ class TestMessageFlowIntegration:
|
|||||||
Note: This does NOT run actual LangBot pipeline stages.
|
Note: This does NOT run actual LangBot pipeline stages.
|
||||||
"""
|
"""
|
||||||
# Setup
|
# Setup
|
||||||
platform = FakePlatform(bot_account_id='test-bot')
|
platform = FakePlatform(bot_account_id="test-bot")
|
||||||
provider = fake_provider_pong()
|
provider = fake_provider_pong()
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
|
|
||||||
# Create inbound message
|
# Create inbound message
|
||||||
query = text_query('ping')
|
query = text_query("ping")
|
||||||
|
|
||||||
# Simulate provider processing
|
# Simulate provider processing
|
||||||
response = await provider.invoke_llm(
|
response = await provider.invoke_llm(
|
||||||
query=query,
|
query=query,
|
||||||
model=model,
|
model=model,
|
||||||
messages=[{'role': 'user', 'content': 'ping'}],
|
messages=[{"role": "user", "content": "ping"}],
|
||||||
funcs=[],
|
funcs=[],
|
||||||
extra_args={},
|
extra_args={},
|
||||||
)
|
)
|
||||||
@@ -324,16 +321,16 @@ class TestMessageFlowIntegration:
|
|||||||
# Verify platform captured outbound
|
# Verify platform captured outbound
|
||||||
outbound = platform.get_outbound_messages()
|
outbound = platform.get_outbound_messages()
|
||||||
assert len(outbound) == 1
|
assert len(outbound) == 1
|
||||||
assert outbound[0]['type'] == 'reply'
|
assert outbound[0]["type"] == "reply"
|
||||||
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE
|
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_streaming_message_flow(self):
|
async def test_streaming_message_flow(self):
|
||||||
"""Smoke test: streaming message flow."""
|
"""Smoke test: streaming message flow."""
|
||||||
platform = FakePlatform().supports_streaming()
|
platform = FakePlatform().supports_streaming()
|
||||||
provider = FakeProvider().returns_streaming(['Hello', ' there'])
|
provider = FakeProvider().returns_streaming(["Hello", " there"])
|
||||||
model = fake_model(provider=provider)
|
model = fake_model(provider=provider)
|
||||||
query = text_query('hi')
|
query = text_query("hi")
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk in provider.invoke_llm_stream(
|
async for chunk in provider.invoke_llm_stream(
|
||||||
@@ -347,8 +344,8 @@ class TestMessageFlowIntegration:
|
|||||||
|
|
||||||
# Verify streaming worked
|
# Verify streaming worked
|
||||||
assert len(chunks) == 2
|
assert len(chunks) == 2
|
||||||
full_content = ''.join(c.content for c in chunks)
|
full_content = "".join(c.content for c in chunks)
|
||||||
assert full_content == 'Hello there'
|
assert full_content == "Hello there"
|
||||||
|
|
||||||
# Verify platform supports streaming
|
# Verify platform supports streaming
|
||||||
assert await platform.is_stream_output_supported() is True
|
assert await platform.is_stream_output_supported() is True
|
||||||
@@ -15,12 +15,22 @@ import pathlib
|
|||||||
# Resolve project root (one level up from tests/)
|
# Resolve project root (one level up from tests/)
|
||||||
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py'
|
VULN_FILE = (
|
||||||
|
_PROJECT_ROOT
|
||||||
|
/ "src"
|
||||||
|
/ "langbot"
|
||||||
|
/ "pkg"
|
||||||
|
/ "api"
|
||||||
|
/ "http"
|
||||||
|
/ "controller"
|
||||||
|
/ "groups"
|
||||||
|
/ "system.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_no_exec_call_in_system_controller():
|
def test_no_exec_call_in_system_controller():
|
||||||
"""Verify there is no exec() call in system.py that takes user input."""
|
"""Verify there is no exec() call in system.py that takes user input."""
|
||||||
with open(VULN_FILE, 'r') as f:
|
with open(VULN_FILE, "r") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
|
|
||||||
tree = ast.parse(source)
|
tree = ast.parse(source)
|
||||||
@@ -30,26 +40,27 @@ def test_no_exec_call_in_system_controller():
|
|||||||
if isinstance(node, ast.Call):
|
if isinstance(node, ast.Call):
|
||||||
func = node.func
|
func = node.func
|
||||||
# Match bare exec() call
|
# Match bare exec() call
|
||||||
if isinstance(func, ast.Name) and func.id == 'exec':
|
if isinstance(func, ast.Name) and func.id == "exec":
|
||||||
exec_calls.append(node.lineno)
|
exec_calls.append(node.lineno)
|
||||||
|
|
||||||
assert len(exec_calls) == 0, (
|
assert len(exec_calls) == 0, (
|
||||||
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().'
|
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
|
||||||
|
"User-supplied code must never be passed to exec()."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_no_debug_exec_route():
|
def test_no_debug_exec_route():
|
||||||
"""Verify the /debug/exec route is not registered."""
|
"""Verify the /debug/exec route is not registered."""
|
||||||
with open(VULN_FILE, 'r') as f:
|
with open(VULN_FILE, "r") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
|
|
||||||
assert 'debug/exec' not in source, (
|
assert "debug/exec" not in source, (
|
||||||
'The /debug/exec route still exists in system.py. '
|
"The /debug/exec route still exists in system.py. "
|
||||||
'This endpoint allows arbitrary code execution and must be removed.'
|
"This endpoint allows arbitrary code execution and must be removed."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
test_no_exec_call_in_system_controller()
|
test_no_exec_call_in_system_controller()
|
||||||
test_no_debug_exec_route()
|
test_no_debug_exec_route()
|
||||||
print('All tests passed!')
|
print("All tests passed!")
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
"""Unit tests for LangBot API HTTP service layer."""
|
"""Unit tests for LangBot API HTTP service layer."""
|
||||||
@@ -13,4 +13,4 @@ Does NOT:
|
|||||||
- Call real provider/platform/network
|
- Call real provider/platform/network
|
||||||
|
|
||||||
Uses tests.factories.FakeApp as base mock application.
|
Uses tests.factories.FakeApp as base mock application.
|
||||||
"""
|
"""
|
||||||
@@ -132,7 +132,9 @@ class TestApiKeyServiceCreateApiKey:
|
|||||||
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
|
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
|
||||||
result = await service.create_api_key('New Key', 'Test description')
|
result = await service.create_api_key('New Key', 'Test description')
|
||||||
|
|
||||||
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}]
|
assert insert_params == [
|
||||||
|
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
|
||||||
|
]
|
||||||
assert result['key'].startswith('lbk_')
|
assert result['key'].startswith('lbk_')
|
||||||
assert result['key'] == 'lbk_fixed-token'
|
assert result['key'] == 'lbk_fixed-token'
|
||||||
assert result['name'] == 'New Key'
|
assert result['name'] == 'New Key'
|
||||||
|
|||||||
@@ -303,7 +303,13 @@ class TestBotServiceCreateBot:
|
|||||||
ap = SimpleNamespace()
|
ap = SimpleNamespace()
|
||||||
ap.persistence_mgr = SimpleNamespace()
|
ap.persistence_mgr = SimpleNamespace()
|
||||||
ap.instance_config = SimpleNamespace()
|
ap.instance_config = SimpleNamespace()
|
||||||
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}}
|
ap.instance_config.data = {
|
||||||
|
'system': {
|
||||||
|
'limitation': {
|
||||||
|
'max_bots': 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ap.platform_mgr = SimpleNamespace()
|
ap.platform_mgr = SimpleNamespace()
|
||||||
ap.platform_mgr.load_bot = AsyncMock()
|
ap.platform_mgr.load_bot = AsyncMock()
|
||||||
|
|
||||||
@@ -312,7 +318,9 @@ class TestBotServiceCreateBot:
|
|||||||
bot2 = _create_mock_bot(bot_uuid='uuid-2')
|
bot2 = _create_mock_bot(bot_uuid='uuid-2')
|
||||||
mock_result = _create_mock_result([bot1, bot2])
|
mock_result = _create_mock_result([bot1, bot2])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = BotService(ap)
|
service = BotService(ap)
|
||||||
|
|
||||||
@@ -344,7 +352,6 @@ class TestBotServiceCreateBot:
|
|||||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -355,7 +362,9 @@ class TestBotServiceCreateBot:
|
|||||||
return bot_result # Get bot
|
return bot_result # Get bot
|
||||||
|
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
|
||||||
|
)
|
||||||
|
|
||||||
service = BotService(ap)
|
service = BotService(ap)
|
||||||
|
|
||||||
@@ -388,7 +397,6 @@ class TestBotServiceCreateBot:
|
|||||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -484,7 +492,6 @@ class TestBotServiceUpdateBot:
|
|||||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -575,9 +582,10 @@ class TestBotServiceListEventLogs:
|
|||||||
# Mock runtime bot with logger
|
# Mock runtime bot with logger
|
||||||
runtime_bot = SimpleNamespace()
|
runtime_bot = SimpleNamespace()
|
||||||
runtime_bot.logger = SimpleNamespace()
|
runtime_bot.logger = SimpleNamespace()
|
||||||
runtime_bot.logger.get_logs = AsyncMock(
|
runtime_bot.logger.get_logs = AsyncMock(return_value=(
|
||||||
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5)
|
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
|
||||||
)
|
5
|
||||||
|
))
|
||||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||||
|
|
||||||
service = BotService(ap)
|
service = BotService(ap)
|
||||||
@@ -638,7 +646,11 @@ class TestBotServiceSendMessage:
|
|||||||
service = BotService(ap)
|
service = BotService(ap)
|
||||||
|
|
||||||
# Execute with valid message chain format
|
# Execute with valid message chain format
|
||||||
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]}
|
message_chain_data = {
|
||||||
|
'messages': [
|
||||||
|
{'type': 'text', 'data': {'text': 'Hello'}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
# Patch the import location - the module imports inside the function
|
# Patch the import location - the module imports inside the function
|
||||||
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
|
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Tests cover:
|
|||||||
- Knowledge engine discovery
|
- Knowledge engine discovery
|
||||||
- File operations
|
- File operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -53,7 +52,9 @@ class TestGetKnowledgeBases:
|
|||||||
"""Test that it returns all knowledge base details."""
|
"""Test that it returns all knowledge base details."""
|
||||||
knowledge_module = get_knowledge_service_module()
|
knowledge_module = get_knowledge_service_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}])
|
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
|
||||||
|
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
|
||||||
|
)
|
||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
result = await service.get_knowledge_bases()
|
result = await service.get_knowledge_bases()
|
||||||
@@ -82,7 +83,9 @@ class TestGetKnowledgeBase:
|
|||||||
"""Test that it returns specific KB details."""
|
"""Test that it returns specific KB details."""
|
||||||
knowledge_module = get_knowledge_service_module()
|
knowledge_module = get_knowledge_service_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'})
|
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||||
|
return_value={'uuid': 'kb1', 'name': 'KB1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
result = await service.get_knowledge_base('kb1')
|
result = await service.get_knowledge_base('kb1')
|
||||||
@@ -150,7 +153,9 @@ class TestCreateKnowledgeBase:
|
|||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
|
|
||||||
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'})
|
await service.create_knowledge_base({
|
||||||
|
'knowledge_engine_plugin_id': 'author/engine'
|
||||||
|
})
|
||||||
|
|
||||||
# Check that default name 'Untitled' was used
|
# Check that default name 'Untitled' was used
|
||||||
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
|
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
|
||||||
@@ -165,21 +170,20 @@ class TestUpdateKnowledgeBase:
|
|||||||
"""Test that only mutable fields are updated."""
|
"""Test that only mutable fields are updated."""
|
||||||
knowledge_module = get_knowledge_service_module()
|
knowledge_module = get_knowledge_service_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'})
|
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||||
|
return_value={'uuid': 'kb1', 'name': 'Updated'}
|
||||||
|
)
|
||||||
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
|
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
|
||||||
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
|
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
|
||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
|
|
||||||
# Pass both mutable and immutable fields
|
# Pass both mutable and immutable fields
|
||||||
await service.update_knowledge_base(
|
await service.update_knowledge_base('kb1', {
|
||||||
'kb1',
|
'name': 'New Name',
|
||||||
{
|
'description': 'New desc',
|
||||||
'name': 'New Name',
|
'uuid': 'should_be_filtered', # immutable
|
||||||
'description': 'New desc',
|
})
|
||||||
'uuid': 'should_be_filtered', # immutable
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that only mutable fields were passed to update
|
# Check that only mutable fields were passed to update
|
||||||
call_args = mock_app.persistence_mgr.execute_async.call_args
|
call_args = mock_app.persistence_mgr.execute_async.call_args
|
||||||
@@ -284,7 +288,9 @@ class TestListKnowledgeEngines:
|
|||||||
"""Test that it returns empty list and logs warning on exception."""
|
"""Test that it returns empty list and logs warning on exception."""
|
||||||
knowledge_module = get_knowledge_service_module()
|
knowledge_module = get_knowledge_service_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error'))
|
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||||
|
side_effect=Exception('Connection error')
|
||||||
|
)
|
||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
result = await service.list_knowledge_engines()
|
result = await service.list_knowledge_engines()
|
||||||
@@ -380,10 +386,12 @@ class TestGetEngineSchemas:
|
|||||||
"""Test that it returns empty dict and logs warning on exception."""
|
"""Test that it returns empty dict and logs warning on exception."""
|
||||||
knowledge_module = get_knowledge_service_module()
|
knowledge_module = get_knowledge_service_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error'))
|
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||||
|
side_effect=Exception('Plugin error')
|
||||||
|
)
|
||||||
|
|
||||||
service = knowledge_module.KnowledgeService(mock_app)
|
service = knowledge_module.KnowledgeService(mock_app)
|
||||||
result = await service.get_engine_creation_schema('author/engine')
|
result = await service.get_engine_creation_schema('author/engine')
|
||||||
|
|
||||||
assert result == {}
|
assert result == {}
|
||||||
mock_app.logger.warning.assert_called_once()
|
mock_app.logger.warning.assert_called_once()
|
||||||
@@ -174,7 +174,9 @@ class TestMaintenanceServiceGetStorageAnalysis:
|
|||||||
# Setup
|
# Setup
|
||||||
ap = SimpleNamespace()
|
ap = SimpleNamespace()
|
||||||
ap.instance_config = SimpleNamespace()
|
ap.instance_config = SimpleNamespace()
|
||||||
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
|
ap.instance_config.data = {
|
||||||
|
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
|
||||||
|
}
|
||||||
ap.persistence_mgr = SimpleNamespace()
|
ap.persistence_mgr = SimpleNamespace()
|
||||||
ap.logger = SimpleNamespace()
|
ap.logger = SimpleNamespace()
|
||||||
ap.logger.warning = Mock()
|
ap.logger.warning = Mock()
|
||||||
@@ -290,8 +292,12 @@ class TestMaintenanceServiceGetStorageAnalysis:
|
|||||||
service._file_count = Mock(return_value=0)
|
service._file_count = Mock(return_value=0)
|
||||||
service._monitoring_counts = AsyncMock(return_value={})
|
service._monitoring_counts = AsyncMock(return_value={})
|
||||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||||
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}])
|
service._expired_uploaded_candidates = AsyncMock(return_value=[
|
||||||
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}])
|
{'key': 'old_file', 'size_bytes': 100}
|
||||||
|
])
|
||||||
|
service._expired_log_candidates = Mock(return_value=[
|
||||||
|
{'name': 'old_log', 'size_bytes': 50}
|
||||||
|
])
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
result = await service.get_storage_analysis()
|
result = await service.get_storage_analysis()
|
||||||
@@ -361,7 +367,6 @@ class TestMaintenanceServiceBinaryStorageStats:
|
|||||||
size_result = _create_mock_result(scalar_value=5000)
|
size_result = _create_mock_result(scalar_value=5000)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -391,7 +396,6 @@ class TestMaintenanceServiceBinaryStorageStats:
|
|||||||
count_result = _create_mock_result(scalar_value=5)
|
count_result = _create_mock_result(scalar_value=5)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -817,4 +821,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
|
|||||||
result = service._expired_local_upload_candidates(7, include_paths=True)
|
result = service._expired_local_upload_candidates(7, include_paths=True)
|
||||||
|
|
||||||
# Verify - path included
|
# Verify - path included
|
||||||
assert 'path' in result[0]
|
assert 'path' in result[0]
|
||||||
@@ -186,7 +186,13 @@ class TestMCPServiceCreateMCPServer:
|
|||||||
ap = SimpleNamespace()
|
ap = SimpleNamespace()
|
||||||
ap.persistence_mgr = SimpleNamespace()
|
ap.persistence_mgr = SimpleNamespace()
|
||||||
ap.instance_config = SimpleNamespace()
|
ap.instance_config = SimpleNamespace()
|
||||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}}
|
ap.instance_config.data = {
|
||||||
|
'system': {
|
||||||
|
'limitation': {
|
||||||
|
'max_extensions': 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ap.plugin_connector = SimpleNamespace()
|
ap.plugin_connector = SimpleNamespace()
|
||||||
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
|
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
|
||||||
|
|
||||||
@@ -246,7 +252,6 @@ class TestMCPServiceCreateMCPServer:
|
|||||||
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
|
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -356,7 +361,6 @@ class TestMCPServiceUpdateMCPServer:
|
|||||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -390,7 +394,6 @@ class TestMCPServiceUpdateMCPServer:
|
|||||||
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -429,7 +432,6 @@ class TestMCPServiceUpdateMCPServer:
|
|||||||
|
|
||||||
# Mock for: first select -> update -> second select (for updated server)
|
# Mock for: first select -> update -> second select (for updated server)
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -463,7 +465,6 @@ class TestMCPServiceUpdateMCPServer:
|
|||||||
|
|
||||||
# Mock execute for select and update
|
# Mock execute for select and update
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -498,7 +499,6 @@ class TestMCPServiceDeleteMCPServer:
|
|||||||
server = _create_mock_mcp_server(name='Server to Delete')
|
server = _create_mock_mcp_server(name='Server to Delete')
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -530,7 +530,6 @@ class TestMCPServiceDeleteMCPServer:
|
|||||||
server = _create_mock_mcp_server(name='Not in Sessions')
|
server = _create_mock_mcp_server(name='Not in Sessions')
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -560,7 +559,6 @@ class TestMCPServiceDeleteMCPServer:
|
|||||||
|
|
||||||
# No server found
|
# No server found
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -598,7 +596,9 @@ class TestMCPServiceTestMCPServer:
|
|||||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||||
|
|
||||||
ap.task_mgr = SimpleNamespace()
|
ap.task_mgr = SimpleNamespace()
|
||||||
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123))
|
ap.task_mgr.create_user_task = Mock(
|
||||||
|
return_value=SimpleNamespace(id=123)
|
||||||
|
)
|
||||||
|
|
||||||
service = MCPService(ap)
|
service = MCPService(ap)
|
||||||
|
|
||||||
@@ -634,7 +634,9 @@ class TestMCPServiceTestMCPServer:
|
|||||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
||||||
|
|
||||||
ap.task_mgr = SimpleNamespace()
|
ap.task_mgr = SimpleNamespace()
|
||||||
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456))
|
ap.task_mgr.create_user_task = Mock(
|
||||||
|
return_value=SimpleNamespace(id=456)
|
||||||
|
)
|
||||||
|
|
||||||
service = MCPService(ap)
|
service = MCPService(ap)
|
||||||
|
|
||||||
@@ -643,4 +645,4 @@ class TestMCPServiceTestMCPServer:
|
|||||||
|
|
||||||
# Verify - load_mcp_server called
|
# Verify - load_mcp_server called
|
||||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
||||||
assert task_id == 456
|
assert task_id == 456
|
||||||
@@ -167,7 +167,6 @@ class TestLLMModelsServiceGetLLMModels:
|
|||||||
mock_provider_result = _create_mock_result([])
|
mock_provider_result = _create_mock_result([])
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
return mock_result if call_count == 0 else mock_provider_result
|
return mock_result if call_count == 0 else mock_provider_result
|
||||||
|
|
||||||
@@ -201,7 +200,6 @@ class TestLLMModelsServiceGetLLMModels:
|
|||||||
mock_provider_result = _create_mock_result([provider])
|
mock_provider_result = _create_mock_result([provider])
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -241,7 +239,6 @@ class TestLLMModelsServiceGetLLMModels:
|
|||||||
mock_provider_result = _create_mock_result([provider])
|
mock_provider_result = _create_mock_result([provider])
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -282,7 +279,6 @@ class TestLLMModelsServiceGetLLMModel:
|
|||||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -341,7 +337,9 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
|
|||||||
|
|
||||||
mock_result = _create_mock_result([model1, model2])
|
mock_result = _create_mock_result([model1, model2])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'model-1', 'name': 'Model 1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = LLMModelsService(ap)
|
service = LLMModelsService(ap)
|
||||||
|
|
||||||
@@ -373,14 +371,12 @@ class TestLLMModelsServiceCreateLLMModel:
|
|||||||
service = LLMModelsService(ap)
|
service = LLMModelsService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
model_uuid = await service.create_llm_model(
|
model_uuid = await service.create_llm_model({
|
||||||
{
|
'name': 'New LLM',
|
||||||
'name': 'New LLM',
|
'provider_uuid': 'provider-uuid',
|
||||||
'provider_uuid': 'provider-uuid',
|
'abilities': [],
|
||||||
'abilities': [],
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert model_uuid is not None
|
assert model_uuid is not None
|
||||||
@@ -404,16 +400,13 @@ class TestLLMModelsServiceCreateLLMModel:
|
|||||||
service = LLMModelsService(ap)
|
service = LLMModelsService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
model_uuid = await service.create_llm_model(
|
model_uuid = await service.create_llm_model({
|
||||||
{
|
'uuid': 'preserved-uuid',
|
||||||
'uuid': 'preserved-uuid',
|
'name': 'Preserved UUID Model',
|
||||||
'name': 'Preserved UUID Model',
|
'provider_uuid': 'provider-uuid',
|
||||||
'provider_uuid': 'provider-uuid',
|
'abilities': [],
|
||||||
'abilities': [],
|
'extra_args': {},
|
||||||
'extra_args': {},
|
}, preserve_uuid=True)
|
||||||
},
|
|
||||||
preserve_uuid=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert model_uuid == 'preserved-uuid'
|
assert model_uuid == 'preserved-uuid'
|
||||||
@@ -466,14 +459,12 @@ class TestLLMModelsServiceCreateLLMModel:
|
|||||||
|
|
||||||
# Execute & Verify
|
# Execute & Verify
|
||||||
with pytest.raises(Exception, match='provider not found'):
|
with pytest.raises(Exception, match='provider not found'):
|
||||||
await service.create_llm_model(
|
await service.create_llm_model({
|
||||||
{
|
'name': 'No Provider Model',
|
||||||
'name': 'No Provider Model',
|
'provider_uuid': 'nonexistent-provider',
|
||||||
'provider_uuid': 'nonexistent-provider',
|
'abilities': [],
|
||||||
'abilities': [],
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_create_llm_model_with_provider_data(self):
|
async def test_create_llm_model_with_provider_data(self):
|
||||||
"""Creates provider when provider data provided."""
|
"""Creates provider when provider data provided."""
|
||||||
@@ -499,18 +490,16 @@ class TestLLMModelsServiceCreateLLMModel:
|
|||||||
service = LLMModelsService(ap)
|
service = LLMModelsService(ap)
|
||||||
|
|
||||||
# Execute - with provider data (no UUID)
|
# Execute - with provider data (no UUID)
|
||||||
result_uuid = await service.create_llm_model(
|
result_uuid = await service.create_llm_model({
|
||||||
{
|
'name': 'Model with New Provider',
|
||||||
'name': 'Model with New Provider',
|
'provider': {
|
||||||
'provider': {
|
'requester': 'openai',
|
||||||
'requester': 'openai',
|
'base_url': 'https://api.openai.com',
|
||||||
'base_url': 'https://api.openai.com',
|
'api_keys': ['key'],
|
||||||
'api_keys': ['key'],
|
},
|
||||||
},
|
'abilities': [],
|
||||||
'abilities': [],
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify - provider_service was called and UUID generated
|
# Verify - provider_service was called and UUID generated
|
||||||
ap.provider_service.find_or_create_provider.assert_called_once()
|
ap.provider_service.find_or_create_provider.assert_called_once()
|
||||||
@@ -536,14 +525,11 @@ class TestLLMModelsServiceUpdateLLMModel:
|
|||||||
service = LLMModelsService(ap)
|
service = LLMModelsService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
await service.update_llm_model(
|
await service.update_llm_model('existing-uuid', {
|
||||||
'existing-uuid',
|
'uuid': 'should-be-removed',
|
||||||
{
|
'name': 'Updated Name',
|
||||||
'uuid': 'should-be-removed',
|
'provider_uuid': 'provider-uuid',
|
||||||
'name': 'Updated Name',
|
})
|
||||||
'provider_uuid': 'provider-uuid',
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify - remove and load called
|
# Verify - remove and load called
|
||||||
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
|
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
|
||||||
@@ -563,13 +549,10 @@ class TestLLMModelsServiceUpdateLLMModel:
|
|||||||
|
|
||||||
# Execute & Verify
|
# Execute & Verify
|
||||||
with pytest.raises(Exception, match='provider not found'):
|
with pytest.raises(Exception, match='provider not found'):
|
||||||
await service.update_llm_model(
|
await service.update_llm_model('model-uuid', {
|
||||||
'model-uuid',
|
'name': 'Update',
|
||||||
{
|
'provider_uuid': 'nonexistent-provider',
|
||||||
'name': 'Update',
|
})
|
||||||
'provider_uuid': 'nonexistent-provider',
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_update_llm_model_reloads_context_length_as_column(self):
|
async def test_update_llm_model_reloads_context_length_as_column(self):
|
||||||
"""Updates runtime model with context_length outside extra_args."""
|
"""Updates runtime model with context_length outside extra_args."""
|
||||||
@@ -635,7 +618,9 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
|
|||||||
|
|
||||||
mock_result = _create_mock_result([])
|
mock_result = _create_mock_result([])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
|
||||||
|
)
|
||||||
|
|
||||||
service = EmbeddingModelsService(ap)
|
service = EmbeddingModelsService(ap)
|
||||||
|
|
||||||
@@ -658,7 +643,6 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
|
|||||||
mock_provider_result = _create_mock_result([provider])
|
mock_provider_result = _create_mock_result([provider])
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -699,7 +683,6 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
|
|||||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -759,13 +742,11 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
|||||||
service = EmbeddingModelsService(ap)
|
service = EmbeddingModelsService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
model_uuid = await service.create_embedding_model(
|
model_uuid = await service.create_embedding_model({
|
||||||
{
|
'name': 'New Embedding',
|
||||||
'name': 'New Embedding',
|
'provider_uuid': 'provider-uuid',
|
||||||
'provider_uuid': 'provider-uuid',
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert model_uuid is not None
|
assert model_uuid is not None
|
||||||
@@ -786,13 +767,11 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
|||||||
|
|
||||||
# Execute & Verify
|
# Execute & Verify
|
||||||
with pytest.raises(Exception, match='provider not found'):
|
with pytest.raises(Exception, match='provider not found'):
|
||||||
await service.create_embedding_model(
|
await service.create_embedding_model({
|
||||||
{
|
'name': 'No Provider Embedding',
|
||||||
'name': 'No Provider Embedding',
|
'provider_uuid': 'nonexistent',
|
||||||
'provider_uuid': 'nonexistent',
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
|
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
|
||||||
@@ -850,7 +829,6 @@ class TestRerankModelsServiceGetRerankModels:
|
|||||||
mock_provider_result = _create_mock_result([provider])
|
mock_provider_result = _create_mock_result([provider])
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -891,7 +869,6 @@ class TestRerankModelsServiceGetRerankModel:
|
|||||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -951,13 +928,11 @@ class TestRerankModelsServiceCreateRerankModel:
|
|||||||
service = RerankModelsService(ap)
|
service = RerankModelsService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
model_uuid = await service.create_rerank_model(
|
model_uuid = await service.create_rerank_model({
|
||||||
{
|
'name': 'New Rerank',
|
||||||
'name': 'New Rerank',
|
'provider_uuid': 'provider-uuid',
|
||||||
'provider_uuid': 'provider-uuid',
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
assert model_uuid is not None
|
assert model_uuid is not None
|
||||||
@@ -977,13 +952,11 @@ class TestRerankModelsServiceCreateRerankModel:
|
|||||||
|
|
||||||
# Execute & Verify
|
# Execute & Verify
|
||||||
with pytest.raises(Exception, match='provider not found'):
|
with pytest.raises(Exception, match='provider not found'):
|
||||||
await service.create_rerank_model(
|
await service.create_rerank_model({
|
||||||
{
|
'name': 'No Provider Rerank',
|
||||||
'name': 'No Provider Rerank',
|
'provider_uuid': 'nonexistent',
|
||||||
'provider_uuid': 'nonexistent',
|
'extra_args': {},
|
||||||
'extra_args': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRerankModelsServiceDeleteRerankModel:
|
class TestRerankModelsServiceDeleteRerankModel:
|
||||||
@@ -1022,7 +995,9 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
|
|||||||
|
|
||||||
mock_result = _create_mock_result([model1, model2])
|
mock_result = _create_mock_result([model1, model2])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = EmbeddingModelsService(ap)
|
service = EmbeddingModelsService(ap)
|
||||||
|
|
||||||
@@ -1047,7 +1022,9 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
|
|||||||
|
|
||||||
mock_result = _create_mock_result([model1, model2])
|
mock_result = _create_mock_result([model1, model2])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = RerankModelsService(ap)
|
service = RerankModelsService(ap)
|
||||||
|
|
||||||
@@ -1065,10 +1042,14 @@ class TestValidateProviderSupports:
|
|||||||
def _make_ap(requester_name: str, support_type):
|
def _make_ap(requester_name: str, support_type):
|
||||||
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
|
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
|
||||||
manifest = SimpleNamespace(spec={'support_type': support_type})
|
manifest = SimpleNamespace(spec={'support_type': support_type})
|
||||||
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name))
|
runtime_provider = SimpleNamespace(
|
||||||
|
provider_entity=SimpleNamespace(requester=requester_name)
|
||||||
|
)
|
||||||
model_mgr = SimpleNamespace(
|
model_mgr = SimpleNamespace(
|
||||||
provider_dict={'p1': runtime_provider},
|
provider_dict={'p1': runtime_provider},
|
||||||
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None,
|
get_available_requester_manifest_by_name=lambda name: manifest
|
||||||
|
if name == requester_name
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
return SimpleNamespace(model_mgr=model_mgr)
|
return SimpleNamespace(model_mgr=model_mgr)
|
||||||
|
|
||||||
@@ -1085,7 +1066,9 @@ class TestValidateProviderSupports:
|
|||||||
async def test_allows_when_support_type_missing(self):
|
async def test_allows_when_support_type_missing(self):
|
||||||
# Manifest without support_type must not block (backward compatible)
|
# Manifest without support_type must not block (backward compatible)
|
||||||
manifest = SimpleNamespace(spec={})
|
manifest = SimpleNamespace(spec={})
|
||||||
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy'))
|
runtime_provider = SimpleNamespace(
|
||||||
|
provider_entity=SimpleNamespace(requester='legacy')
|
||||||
|
)
|
||||||
model_mgr = SimpleNamespace(
|
model_mgr = SimpleNamespace(
|
||||||
provider_dict={'p1': runtime_provider},
|
provider_dict={'p1': runtime_provider},
|
||||||
get_available_requester_manifest_by_name=lambda name: manifest,
|
get_available_requester_manifest_by_name=lambda name: manifest,
|
||||||
|
|||||||
@@ -215,7 +215,13 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
ap = SimpleNamespace()
|
ap = SimpleNamespace()
|
||||||
ap.persistence_mgr = SimpleNamespace()
|
ap.persistence_mgr = SimpleNamespace()
|
||||||
ap.instance_config = SimpleNamespace()
|
ap.instance_config = SimpleNamespace()
|
||||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
|
ap.instance_config.data = {
|
||||||
|
'system': {
|
||||||
|
'limitation': {
|
||||||
|
'max_pipelines': 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ap.pipeline_mgr = SimpleNamespace()
|
ap.pipeline_mgr = SimpleNamespace()
|
||||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||||
ap.ver_mgr = SimpleNamespace()
|
ap.ver_mgr = SimpleNamespace()
|
||||||
@@ -223,7 +229,9 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
|
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
|
||||||
|
)
|
||||||
|
|
||||||
service = PipelineService(ap)
|
service = PipelineService(ap)
|
||||||
|
|
||||||
@@ -250,14 +258,14 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
# Mock persistence for insert
|
# Mock persistence for insert
|
||||||
ap.persistence_mgr.execute_async = AsyncMock()
|
ap.persistence_mgr.execute_async = AsyncMock()
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
|
||||||
|
)
|
||||||
|
|
||||||
# Mock the file read for default config - patch at the utils module level
|
# Mock the file read for default config - patch at the utils module level
|
||||||
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
|
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
|
||||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||||
with patch(
|
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
|
||||||
):
|
|
||||||
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
|
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
@@ -278,9 +286,7 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
service = PipelineService(ap)
|
service = PipelineService(ap)
|
||||||
service.get_pipelines = AsyncMock(return_value=[])
|
service.get_pipelines = AsyncMock(return_value=[])
|
||||||
service.get_pipeline = AsyncMock(
|
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
|
||||||
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
|
|
||||||
)
|
|
||||||
|
|
||||||
ap.persistence_mgr.execute_async = AsyncMock()
|
ap.persistence_mgr.execute_async = AsyncMock()
|
||||||
ap.persistence_mgr.serialize_model = Mock(
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
@@ -290,9 +296,7 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
# Mock the file read
|
# Mock the file read
|
||||||
default_config = {}
|
default_config = {}
|
||||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||||
with patch(
|
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
|
||||||
):
|
|
||||||
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
|
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
|
||||||
|
|
||||||
# Verify - execute was called
|
# Verify - execute was called
|
||||||
@@ -312,12 +316,10 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
service = PipelineService(ap)
|
service = PipelineService(ap)
|
||||||
service.get_pipelines = AsyncMock(return_value=[])
|
service.get_pipelines = AsyncMock(return_value=[])
|
||||||
service.get_pipeline = AsyncMock(
|
service.get_pipeline = AsyncMock(return_value={
|
||||||
return_value={
|
'uuid': 'new-uuid',
|
||||||
'uuid': 'new-uuid',
|
'extensions_preferences': {},
|
||||||
'extensions_preferences': {},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
insert_params = []
|
insert_params = []
|
||||||
|
|
||||||
@@ -337,9 +339,7 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
default_config = {}
|
default_config = {}
|
||||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||||
with patch(
|
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
|
||||||
):
|
|
||||||
await service.create_pipeline({'name': 'New Pipeline'})
|
await service.create_pipeline({'name': 'New Pipeline'})
|
||||||
|
|
||||||
assert len(insert_params) == 1
|
assert len(insert_params) == 1
|
||||||
@@ -353,7 +353,6 @@ class TestPipelineServiceCreatePipeline:
|
|||||||
|
|
||||||
class _MockResultWithBots:
|
class _MockResultWithBots:
|
||||||
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
|
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
|
||||||
|
|
||||||
def __init__(self, bots_list):
|
def __init__(self, bots_list):
|
||||||
self._bots_list = bots_list
|
self._bots_list = bots_list
|
||||||
|
|
||||||
@@ -429,7 +428,6 @@ class TestPipelineServiceUpdatePipeline:
|
|||||||
# 1. UPDATE (line 125) - returns Mock (no result needed)
|
# 1. UPDATE (line 125) - returns Mock (no result needed)
|
||||||
# 2. SELECT bots (line 136) - returns bot_result with .all()
|
# 2. SELECT bots (line 136) - returns bot_result with .all()
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -530,7 +528,13 @@ class TestPipelineServiceCopyPipeline:
|
|||||||
ap = SimpleNamespace()
|
ap = SimpleNamespace()
|
||||||
ap.persistence_mgr = SimpleNamespace()
|
ap.persistence_mgr = SimpleNamespace()
|
||||||
ap.instance_config = SimpleNamespace()
|
ap.instance_config = SimpleNamespace()
|
||||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
|
ap.instance_config.data = {
|
||||||
|
'system': {
|
||||||
|
'limitation': {
|
||||||
|
'max_pipelines': 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ap.pipeline_mgr = SimpleNamespace()
|
ap.pipeline_mgr = SimpleNamespace()
|
||||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||||
ap.ver_mgr = SimpleNamespace()
|
ap.ver_mgr = SimpleNamespace()
|
||||||
@@ -538,12 +542,10 @@ class TestPipelineServiceCopyPipeline:
|
|||||||
|
|
||||||
service = PipelineService(ap)
|
service = PipelineService(ap)
|
||||||
# Mock get_pipelines to return 2 pipelines
|
# Mock get_pipelines to return 2 pipelines
|
||||||
service.get_pipelines = AsyncMock(
|
service.get_pipelines = AsyncMock(return_value=[
|
||||||
return_value=[
|
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
||||||
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
||||||
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute & Verify
|
# Execute & Verify
|
||||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||||
@@ -640,7 +642,9 @@ class TestPipelineServiceCopyPipeline:
|
|||||||
service = PipelineService(ap)
|
service = PipelineService(ap)
|
||||||
service.get_pipelines = AsyncMock(return_value=[])
|
service.get_pipelines = AsyncMock(return_value=[])
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'uuid': 'copy-uuid', 'is_default': False}
|
||||||
|
)
|
||||||
|
|
||||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
||||||
|
|
||||||
@@ -677,10 +681,11 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||||
|
|
||||||
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []})
|
original_pipeline = _create_mock_pipeline(
|
||||||
|
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
|
||||||
|
)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -695,7 +700,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
'extensions_preferences': {
|
'extensions_preferences': {
|
||||||
'enable_all_plugins': False,
|
'enable_all_plugins': False,
|
||||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -706,7 +711,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
'extensions_preferences': {
|
'extensions_preferences': {
|
||||||
'enable_all_plugins': False,
|
'enable_all_plugins': False,
|
||||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -733,7 +738,6 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
original_pipeline = _create_mock_pipeline()
|
original_pipeline = _create_mock_pipeline()
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -748,7 +752,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
'extensions_preferences': {
|
'extensions_preferences': {
|
||||||
'enable_all_mcp_servers': False,
|
'enable_all_mcp_servers': False,
|
||||||
'mcp_servers': ['mcp-server-1'],
|
'mcp_servers': ['mcp-server-1'],
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -790,7 +794,6 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
|||||||
)
|
)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|||||||
@@ -245,14 +245,12 @@ class TestModelProviderServiceCreateProvider:
|
|||||||
service = ModelProviderService(ap)
|
service = ModelProviderService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
provider_uuid = await service.create_provider(
|
provider_uuid = await service.create_provider({
|
||||||
{
|
'name': 'New Provider',
|
||||||
'name': 'New Provider',
|
'requester': 'openai',
|
||||||
'requester': 'openai',
|
'base_url': 'https://api.openai.com',
|
||||||
'base_url': 'https://api.openai.com',
|
'api_keys': ['key'],
|
||||||
'api_keys': ['key'],
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify - UUID is generated
|
# Verify - UUID is generated
|
||||||
assert provider_uuid is not None
|
assert provider_uuid is not None
|
||||||
@@ -276,14 +274,12 @@ class TestModelProviderServiceCreateProvider:
|
|||||||
service = ModelProviderService(ap)
|
service = ModelProviderService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
result_uuid = await service.create_provider(
|
result_uuid = await service.create_provider({
|
||||||
{
|
'name': 'Runtime Provider',
|
||||||
'name': 'Runtime Provider',
|
'requester': 'openai',
|
||||||
'requester': 'openai',
|
'base_url': 'https://api.openai.com',
|
||||||
'base_url': 'https://api.openai.com',
|
'api_keys': ['key'],
|
||||||
'api_keys': ['key'],
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify - provider added to runtime dict and UUID generated
|
# Verify - provider added to runtime dict and UUID generated
|
||||||
ap.model_mgr.load_provider.assert_called_once()
|
ap.model_mgr.load_provider.assert_called_once()
|
||||||
@@ -306,13 +302,10 @@ class TestModelProviderServiceUpdateProvider:
|
|||||||
service = ModelProviderService(ap)
|
service = ModelProviderService(ap)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
await service.update_provider(
|
await service.update_provider('existing-uuid', {
|
||||||
'existing-uuid',
|
'uuid': 'should-be-removed', # Will be removed
|
||||||
{
|
'name': 'Updated Name',
|
||||||
'uuid': 'should-be-removed', # Will be removed
|
})
|
||||||
'name': 'Updated Name',
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify - reload called
|
# Verify - reload called
|
||||||
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
|
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
|
||||||
@@ -371,7 +364,6 @@ class TestModelProviderServiceDeleteProvider:
|
|||||||
rerank_result.first = Mock(return_value=None)
|
rerank_result.first = Mock(return_value=None)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -404,7 +396,6 @@ class TestModelProviderServiceDeleteProvider:
|
|||||||
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
|
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -463,7 +454,6 @@ class TestModelProviderServiceGetProviderModelCounts:
|
|||||||
rerank_result.scalar = Mock(return_value=1)
|
rerank_result.scalar = Mock(return_value=1)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -647,7 +637,9 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
|
|||||||
await service.update_space_model_provider_api_keys('space-api-key')
|
await service.update_space_model_provider_api_keys('space-api-key')
|
||||||
|
|
||||||
# Verify - update and reload called for Space provider UUID
|
# Verify - update and reload called for Space provider UUID
|
||||||
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000')
|
ap.model_mgr.reload_provider.assert_called_once_with(
|
||||||
|
'00000000-0000-0000-0000-000000000000'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestModelProviderServiceScanProviderModels:
|
class TestModelProviderServiceScanProviderModels:
|
||||||
@@ -803,7 +795,9 @@ class TestModelProviderServiceScanProviderModels:
|
|||||||
runtime_provider.token_mgr = Mock()
|
runtime_provider.token_mgr = Mock()
|
||||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||||
runtime_provider.token_mgr.tokens = ['token']
|
runtime_provider.token_mgr.tokens = ['token']
|
||||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported'))
|
runtime_provider.requester.scan_models = AsyncMock(
|
||||||
|
side_effect=NotImplementedError('scan not supported')
|
||||||
|
)
|
||||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||||
|
|
||||||
service = ModelProviderService(ap)
|
service = ModelProviderService(ap)
|
||||||
@@ -854,7 +848,9 @@ class TestModelProviderServiceScanProviderModels:
|
|||||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||||
|
|
||||||
# Mock existing LLM model
|
# Mock existing LLM model
|
||||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}])
|
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
|
||||||
|
return_value=[{'name': 'Existing Model'}]
|
||||||
|
)
|
||||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||||
|
|
||||||
service = ModelProviderService(ap)
|
service = ModelProviderService(ap)
|
||||||
@@ -867,4 +863,4 @@ class TestModelProviderServiceScanProviderModels:
|
|||||||
assert existing_model['already_added'] is True
|
assert existing_model['already_added'] is True
|
||||||
|
|
||||||
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
|
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
|
||||||
assert new_model['already_added'] is False
|
assert new_model['already_added'] is False
|
||||||
@@ -393,16 +393,14 @@ class TestSpaceServiceRefreshToken:
|
|||||||
# Mock HTTP response
|
# Mock HTTP response
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.json = AsyncMock(
|
mock_response.json = AsyncMock(return_value={
|
||||||
return_value={
|
'code': 0,
|
||||||
'code': 0,
|
'data': {
|
||||||
'data': {
|
'access_token': 'new_access_token',
|
||||||
'access_token': 'new_access_token',
|
'refresh_token': 'new_refresh_token',
|
||||||
'refresh_token': 'new_refresh_token',
|
'expires_in': 3600,
|
||||||
'expires_in': 3600,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
})
|
||||||
|
|
||||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||||
mock_session_obj = MagicMock()
|
mock_session_obj = MagicMock()
|
||||||
@@ -431,12 +429,10 @@ class TestSpaceServiceRefreshToken:
|
|||||||
# Mock HTTP response with error
|
# Mock HTTP response with error
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.json = AsyncMock(
|
mock_response.json = AsyncMock(return_value={
|
||||||
return_value={
|
'code': 1,
|
||||||
'code': 1,
|
'msg': 'Invalid refresh token',
|
||||||
'msg': 'Invalid refresh token',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
|
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
|
||||||
|
|
||||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||||
@@ -493,16 +489,14 @@ class TestSpaceServiceExchangeOAuthCode:
|
|||||||
# Mock HTTP response
|
# Mock HTTP response
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.json = AsyncMock(
|
mock_response.json = AsyncMock(return_value={
|
||||||
return_value={
|
'code': 0,
|
||||||
'code': 0,
|
'data': {
|
||||||
'data': {
|
'access_token': 'new_access_token',
|
||||||
'access_token': 'new_access_token',
|
'refresh_token': 'new_refresh_token',
|
||||||
'refresh_token': 'new_refresh_token',
|
'expires_in': 3600,
|
||||||
'expires_in': 3600,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
})
|
||||||
|
|
||||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||||
mock_session_obj = MagicMock()
|
mock_session_obj = MagicMock()
|
||||||
@@ -561,15 +555,13 @@ class TestSpaceServiceGetUserInfoRaw:
|
|||||||
# Mock HTTP response
|
# Mock HTTP response
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.json = AsyncMock(
|
mock_response.json = AsyncMock(return_value={
|
||||||
return_value={
|
'code': 0,
|
||||||
'code': 0,
|
'data': {
|
||||||
'data': {
|
'email': 'test@example.com',
|
||||||
'email': 'test@example.com',
|
'credits': 100,
|
||||||
'credits': 100,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
})
|
||||||
|
|
||||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||||
mock_session_obj = MagicMock()
|
mock_session_obj = MagicMock()
|
||||||
@@ -677,29 +669,27 @@ class TestSpaceServiceGetModels:
|
|||||||
# Mock HTTP response with proper model data matching SpaceModel schema
|
# Mock HTTP response with proper model data matching SpaceModel schema
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status = 200
|
mock_response.status = 200
|
||||||
mock_response.json = AsyncMock(
|
mock_response.json = AsyncMock(return_value={
|
||||||
return_value={
|
'code': 0,
|
||||||
'code': 0,
|
'data': {
|
||||||
'data': {
|
'models': [
|
||||||
'models': [
|
{
|
||||||
{
|
'uuid': 'uuid-1',
|
||||||
'uuid': 'uuid-1',
|
'model_id': 'model-1',
|
||||||
'model_id': 'model-1',
|
'provider': 'provider-1',
|
||||||
'provider': 'provider-1',
|
'category': 'chat',
|
||||||
'category': 'chat',
|
'status': 'active',
|
||||||
'status': 'active',
|
},
|
||||||
},
|
{
|
||||||
{
|
'uuid': 'uuid-2',
|
||||||
'uuid': 'uuid-2',
|
'model_id': 'model-2',
|
||||||
'model_id': 'model-2',
|
'provider': 'provider-2',
|
||||||
'provider': 'provider-2',
|
'category': 'chat',
|
||||||
'category': 'chat',
|
'status': 'active',
|
||||||
'status': 'active',
|
},
|
||||||
},
|
]
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
})
|
||||||
|
|
||||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||||
mock_session_obj = MagicMock()
|
mock_session_obj = MagicMock()
|
||||||
@@ -785,4 +775,4 @@ class TestSpaceServiceCreditsCache:
|
|||||||
# Verify - cache updated
|
# Verify - cache updated
|
||||||
assert result == 500
|
assert result == 500
|
||||||
assert 'test@example.com' in service._credits_cache
|
assert 'test@example.com' in service._credits_cache
|
||||||
assert service._credits_cache['test@example.com'][0] == 500
|
assert service._credits_cache['test@example.com'][0] == 500
|
||||||
@@ -495,7 +495,6 @@ class TestUserServiceCreateOrUpdateSpaceUser:
|
|||||||
|
|
||||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_get_by_space_uuid(uuid):
|
async def mock_get_by_space_uuid(uuid):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -566,7 +565,6 @@ class TestUserServiceCreateOrUpdateSpaceUser:
|
|||||||
|
|
||||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_get_by_space_uuid(uuid):
|
async def mock_get_by_space_uuid(uuid):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -607,4 +605,4 @@ class TestUserServiceCreateUserLock:
|
|||||||
|
|
||||||
# Verify lock exists
|
# Verify lock exists
|
||||||
assert hasattr(service, '_create_user_lock')
|
assert hasattr(service, '_create_user_lock')
|
||||||
assert service._create_user_lock is not None
|
assert service._create_user_lock is not None
|
||||||
@@ -132,7 +132,6 @@ class TestWebhookServiceCreateWebhook:
|
|||||||
|
|
||||||
# execute_async returns different results
|
# execute_async returns different results
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -182,7 +181,6 @@ class TestWebhookServiceCreateWebhook:
|
|||||||
)
|
)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -219,7 +217,6 @@ class TestWebhookServiceCreateWebhook:
|
|||||||
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
|
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_execute(query):
|
async def mock_execute(query):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
@@ -228,7 +225,9 @@ class TestWebhookServiceCreateWebhook:
|
|||||||
return _create_mock_result(first_item=created_webhook)
|
return _create_mock_result(first_item=created_webhook)
|
||||||
|
|
||||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||||
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False})
|
ap.persistence_mgr.serialize_model = Mock(
|
||||||
|
return_value={'id': 1, 'enabled': False}
|
||||||
|
)
|
||||||
|
|
||||||
service = WebhookService(ap)
|
service = WebhookService(ap)
|
||||||
|
|
||||||
@@ -504,4 +503,4 @@ class TestWebhookServiceGetEnabledWebhooks:
|
|||||||
result = await service.get_enabled_webhooks()
|
result = await service.get_enabled_webhooks()
|
||||||
|
|
||||||
# Verify - should be empty (SQL would filter disabled)
|
# Verify - should be empty (SQL would filter disabled)
|
||||||
assert result == []
|
assert result == []
|
||||||
@@ -407,9 +407,7 @@ def test_box_service_forced_template_ignores_pipeline_config():
|
|||||||
launcher_type='person',
|
launcher_type='person',
|
||||||
launcher_id='test_user',
|
launcher_id='test_user',
|
||||||
sender_id='test_user',
|
sender_id='test_user',
|
||||||
pipeline_config={
|
pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
|
||||||
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert service.resolve_box_session_id(query) == 'global'
|
assert service.resolve_box_session_id(query) == 'global'
|
||||||
@@ -1529,7 +1527,9 @@ class TestBuildSkillExtraMounts:
|
|||||||
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
|
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
|
||||||
]
|
]
|
||||||
# No skill is dropped, so no "missing" warning should be logged.
|
# No skill is dropped, so no "missing" warning should be logged.
|
||||||
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list)
|
assert not any(
|
||||||
|
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
|
||||||
|
)
|
||||||
|
|
||||||
def test_skips_skill_with_empty_package_root(self):
|
def test_skips_skill_with_empty_package_root(self):
|
||||||
logger = Mock()
|
logger = Mock()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
# Unit tests for command module
|
# Unit tests for command module
|
||||||
@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
|
|||||||
|
|
||||||
# Should yield CommandNotFoundError (no such command registered)
|
# Should yield CommandNotFoundError (no such command registered)
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].error is not None
|
assert results[0].error is not None
|
||||||
@@ -197,7 +197,6 @@ class TestCommandOperatorBase:
|
|||||||
op = TestOperator(None)
|
op = TestOperator(None)
|
||||||
# Should not raise
|
# Should not raise
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(op.initialize())
|
asyncio.get_event_loop().run_until_complete(op.initialize())
|
||||||
|
|
||||||
def test_execute_is_abstract(self):
|
def test_execute_is_abstract(self):
|
||||||
@@ -300,4 +299,4 @@ class TestMultipleOperators:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
assert AdminOperator.lowest_privilege == 2
|
assert AdminOperator.lowest_privilege == 2
|
||||||
assert SubOperator.lowest_privilege == 1
|
assert SubOperator.lowest_privilege == 1
|
||||||
@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_yaml_loads(self, tmp_path):
|
async def test_valid_yaml_loads(self, tmp_path):
|
||||||
"""Valid YAML config should load correctly."""
|
"""Valid YAML config should load correctly."""
|
||||||
config_file = tmp_path / 'test_config.yaml'
|
config_file = tmp_path / "test_config.yaml"
|
||||||
|
|
||||||
# Write valid YAML
|
# Write valid YAML
|
||||||
config_file.write_text("""
|
config_file.write_text("""
|
||||||
@@ -51,7 +51,7 @@ settings:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_yaml_raises_error(self, tmp_path):
|
async def test_invalid_yaml_raises_error(self, tmp_path):
|
||||||
"""Invalid YAML should raise clear error."""
|
"""Invalid YAML should raise clear error."""
|
||||||
config_file = tmp_path / 'invalid.yaml'
|
config_file = tmp_path / "invalid.yaml"
|
||||||
|
|
||||||
# Write invalid YAML (unclosed bracket)
|
# Write invalid YAML (unclosed bracket)
|
||||||
config_file.write_text("""
|
config_file.write_text("""
|
||||||
@@ -67,13 +67,13 @@ settings:
|
|||||||
template_data={'name': 'default'},
|
template_data={'name': 'default'},
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match='Syntax error'):
|
with pytest.raises(Exception, match="Syntax error"):
|
||||||
await yaml_file.load(completion=False)
|
await yaml_file.load(completion=False)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_config_creates_from_template(self, tmp_path):
|
async def test_missing_config_creates_from_template(self, tmp_path):
|
||||||
"""Missing config file should be created from template."""
|
"""Missing config file should be created from template."""
|
||||||
config_file = tmp_path / 'new_config.yaml'
|
config_file = tmp_path / "new_config.yaml"
|
||||||
|
|
||||||
# File doesn't exist yet
|
# File doesn't exist yet
|
||||||
assert not config_file.exists()
|
assert not config_file.exists()
|
||||||
@@ -92,7 +92,7 @@ settings:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_template_completion(self, tmp_path):
|
async def test_template_completion(self, tmp_path):
|
||||||
"""Config should be completed with template defaults."""
|
"""Config should be completed with template defaults."""
|
||||||
config_file = tmp_path / 'partial.yaml'
|
config_file = tmp_path / "partial.yaml"
|
||||||
|
|
||||||
# Write partial config missing some template keys
|
# Write partial config missing some template keys
|
||||||
config_file.write_text("""
|
config_file.write_text("""
|
||||||
@@ -115,7 +115,7 @@ name: custom_name
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_yaml_save(self, tmp_path):
|
async def test_yaml_save(self, tmp_path):
|
||||||
"""YAML config can be saved."""
|
"""YAML config can be saved."""
|
||||||
config_file = tmp_path / 'save_test.yaml'
|
config_file = tmp_path / "save_test.yaml"
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(
|
yaml_file = YAMLConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -131,7 +131,7 @@ name: custom_name
|
|||||||
|
|
||||||
def test_yaml_save_sync(self, tmp_path):
|
def test_yaml_save_sync(self, tmp_path):
|
||||||
"""YAML config can be saved synchronously."""
|
"""YAML config can be saved synchronously."""
|
||||||
config_file = tmp_path / 'sync_save.yaml'
|
config_file = tmp_path / "sync_save.yaml"
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(
|
yaml_file = YAMLConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -151,18 +151,14 @@ class TestJSONConfigFile:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_json_loads(self, tmp_path):
|
async def test_valid_json_loads(self, tmp_path):
|
||||||
"""Valid JSON config should load correctly."""
|
"""Valid JSON config should load correctly."""
|
||||||
config_file = tmp_path / 'test_config.json'
|
config_file = tmp_path / "test_config.json"
|
||||||
|
|
||||||
# Write valid JSON
|
# Write valid JSON
|
||||||
config_file.write_text(
|
config_file.write_text(json.dumps({
|
||||||
json.dumps(
|
'name': 'json_app',
|
||||||
{
|
'version': '1.0',
|
||||||
'name': 'json_app',
|
'settings': {'debug': True, 'port': 8080},
|
||||||
'version': '1.0',
|
}))
|
||||||
'settings': {'debug': True, 'port': 8080},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
json_file = JSONConfigFile(
|
json_file = JSONConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -178,7 +174,7 @@ class TestJSONConfigFile:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_json_raises_error(self, tmp_path):
|
async def test_invalid_json_raises_error(self, tmp_path):
|
||||||
"""Invalid JSON should raise clear error."""
|
"""Invalid JSON should raise clear error."""
|
||||||
config_file = tmp_path / 'invalid.json'
|
config_file = tmp_path / "invalid.json"
|
||||||
|
|
||||||
# Write invalid JSON (missing closing brace)
|
# Write invalid JSON (missing closing brace)
|
||||||
config_file.write_text('{"name": "test", "unclosed": ')
|
config_file.write_text('{"name": "test", "unclosed": ')
|
||||||
@@ -188,13 +184,13 @@ class TestJSONConfigFile:
|
|||||||
template_data={'name': 'default'},
|
template_data={'name': 'default'},
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match='Syntax error'):
|
with pytest.raises(Exception, match="Syntax error"):
|
||||||
await json_file.load(completion=False)
|
await json_file.load(completion=False)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_missing_json_creates_from_template(self, tmp_path):
|
async def test_missing_json_creates_from_template(self, tmp_path):
|
||||||
"""Missing JSON file should be created from template."""
|
"""Missing JSON file should be created from template."""
|
||||||
config_file = tmp_path / 'new_config.json'
|
config_file = tmp_path / "new_config.json"
|
||||||
|
|
||||||
json_file = JSONConfigFile(
|
json_file = JSONConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -209,7 +205,7 @@ class TestJSONConfigFile:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_json_save(self, tmp_path):
|
async def test_json_save(self, tmp_path):
|
||||||
"""JSON config can be saved."""
|
"""JSON config can be saved."""
|
||||||
config_file = tmp_path / 'save_test.json'
|
config_file = tmp_path / "save_test.json"
|
||||||
|
|
||||||
json_file = JSONConfigFile(
|
json_file = JSONConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -230,7 +226,7 @@ class TestConfigManager:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_config_manager_load(self, tmp_path):
|
async def test_config_manager_load(self, tmp_path):
|
||||||
"""ConfigManager loads config correctly."""
|
"""ConfigManager loads config correctly."""
|
||||||
config_file = tmp_path / 'manager_test.yaml'
|
config_file = tmp_path / "manager_test.yaml"
|
||||||
config_file.write_text('name: managed_app\nversion: "1.0"\n')
|
config_file.write_text('name: managed_app\nversion: "1.0"\n')
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(
|
yaml_file = YAMLConfigFile(
|
||||||
@@ -247,7 +243,7 @@ class TestConfigManager:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_config_manager_dump(self, tmp_path):
|
async def test_config_manager_dump(self, tmp_path):
|
||||||
"""ConfigManager can dump config."""
|
"""ConfigManager can dump config."""
|
||||||
config_file = tmp_path / 'dump_test.yaml'
|
config_file = tmp_path / "dump_test.yaml"
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(
|
yaml_file = YAMLConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -264,7 +260,7 @@ class TestConfigManager:
|
|||||||
|
|
||||||
def test_config_manager_dump_sync(self, tmp_path):
|
def test_config_manager_dump_sync(self, tmp_path):
|
||||||
"""ConfigManager can dump config synchronously."""
|
"""ConfigManager can dump config synchronously."""
|
||||||
config_file = tmp_path / 'sync_dump.yaml'
|
config_file = tmp_path / "sync_dump.yaml"
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(
|
yaml_file = YAMLConfigFile(
|
||||||
str(config_file),
|
str(config_file),
|
||||||
@@ -284,7 +280,7 @@ class TestConfigExists:
|
|||||||
|
|
||||||
def test_yaml_exists_true(self, tmp_path):
|
def test_yaml_exists_true(self, tmp_path):
|
||||||
"""exists() returns True for existing file."""
|
"""exists() returns True for existing file."""
|
||||||
config_file = tmp_path / 'exists.yaml'
|
config_file = tmp_path / "exists.yaml"
|
||||||
config_file.write_text('name: test')
|
config_file.write_text('name: test')
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
||||||
@@ -292,14 +288,14 @@ class TestConfigExists:
|
|||||||
|
|
||||||
def test_yaml_exists_false(self, tmp_path):
|
def test_yaml_exists_false(self, tmp_path):
|
||||||
"""exists() returns False for missing file."""
|
"""exists() returns False for missing file."""
|
||||||
config_file = tmp_path / 'missing.yaml'
|
config_file = tmp_path / "missing.yaml"
|
||||||
|
|
||||||
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
||||||
assert yaml_file.exists() is False
|
assert yaml_file.exists() is False
|
||||||
|
|
||||||
def test_json_exists_true(self, tmp_path):
|
def test_json_exists_true(self, tmp_path):
|
||||||
"""exists() returns True for existing JSON file."""
|
"""exists() returns True for existing JSON file."""
|
||||||
config_file = tmp_path / 'exists.json'
|
config_file = tmp_path / "exists.json"
|
||||||
config_file.write_text('{}')
|
config_file.write_text('{}')
|
||||||
|
|
||||||
json_file = JSONConfigFile(str(config_file), template_data={})
|
json_file = JSONConfigFile(str(config_file), template_data={})
|
||||||
@@ -307,7 +303,7 @@ class TestConfigExists:
|
|||||||
|
|
||||||
def test_json_exists_false(self, tmp_path):
|
def test_json_exists_false(self, tmp_path):
|
||||||
"""exists() returns False for missing JSON file."""
|
"""exists() returns False for missing JSON file."""
|
||||||
config_file = tmp_path / 'missing.json'
|
config_file = tmp_path / "missing.json"
|
||||||
|
|
||||||
json_file = JSONConfigFile(str(config_file), template_data={})
|
json_file = JSONConfigFile(str(config_file), template_data={})
|
||||||
assert json_file.exists() is False
|
assert json_file.exists() is False
|
||||||
@@ -1 +1 @@
|
|||||||
"""Core module unit tests."""
|
"""Core module unit tests."""
|
||||||
@@ -4,7 +4,6 @@ Tests cover:
|
|||||||
- _get_positive_int_config() validation
|
- _get_positive_int_config() validation
|
||||||
- _get_positive_float_config() validation
|
- _get_positive_float_config() validation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
@@ -189,4 +188,4 @@ class TestGetPositiveFloatConfig:
|
|||||||
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
|
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
|
||||||
|
|
||||||
assert result == 1.5
|
assert result == 1.5
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
@@ -27,7 +27,6 @@ class TestCheckDeps:
|
|||||||
from langbot.pkg.core.bootutils.deps import check_deps
|
from langbot.pkg.core.bootutils.deps import check_deps
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||||
|
|
||||||
assert result == []
|
assert result == []
|
||||||
@@ -47,7 +46,6 @@ class TestCheckDeps:
|
|||||||
from langbot.pkg.core.bootutils.deps import check_deps
|
from langbot.pkg.core.bootutils.deps import check_deps
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||||
|
|
||||||
assert 'requests' in result
|
assert 'requests' in result
|
||||||
@@ -63,7 +61,6 @@ class TestCheckDeps:
|
|||||||
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
|
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||||
|
|
||||||
# Should include all required_deps keys
|
# Should include all required_deps keys
|
||||||
@@ -110,7 +107,6 @@ class TestPrecheckPluginDeps:
|
|||||||
with patch('os.path.exists', return_value=False):
|
with patch('os.path.exists', return_value=False):
|
||||||
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
||||||
|
|
||||||
mock_install.assert_not_called()
|
mock_install.assert_not_called()
|
||||||
@@ -133,7 +129,6 @@ class TestPrecheckPluginDeps:
|
|||||||
with patch('os.listdir', side_effect=mock_listdir):
|
with patch('os.listdir', side_effect=mock_listdir):
|
||||||
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
||||||
|
|
||||||
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])
|
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Tests cover:
|
|||||||
- Dict type skipping
|
- Dict type skipping
|
||||||
- Missing key creation
|
- Missing key creation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -249,8 +248,15 @@ class TestApplyEnvOverridesToConfig:
|
|||||||
"""Test multiple env vars applied in order."""
|
"""Test multiple env vars applied in order."""
|
||||||
load_config = get_load_config_module()
|
load_config = get_load_config_module()
|
||||||
|
|
||||||
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}}
|
cfg = {
|
||||||
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'}
|
'system': {'name': 'default', 'enable': True},
|
||||||
|
'concurrency': {'pipeline': 5}
|
||||||
|
}
|
||||||
|
env = {
|
||||||
|
'SYSTEM__NAME': 'custom',
|
||||||
|
'SYSTEM__ENABLE': 'false',
|
||||||
|
'CONCURRENCY__PIPELINE': '10'
|
||||||
|
}
|
||||||
|
|
||||||
with patch.dict(os.environ, env, clear=True):
|
with patch.dict(os.environ, env, clear=True):
|
||||||
result = load_config._apply_env_overrides_to_config(cfg)
|
result = load_config._apply_env_overrides_to_config(cfg)
|
||||||
@@ -281,4 +287,4 @@ class TestApplyEnvOverridesToConfig:
|
|||||||
with patch.dict(os.environ, env, clear=True):
|
with patch.dict(os.environ, env, clear=True):
|
||||||
result = load_config._apply_env_overrides_to_config(cfg)
|
result = load_config._apply_env_overrides_to_config(cfg)
|
||||||
|
|
||||||
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||||
@@ -175,4 +175,4 @@ class TestPreregisteredStages:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
for key in preregistered_stages:
|
for key in preregistered_stages:
|
||||||
assert isinstance(key, str)
|
assert isinstance(key, str)
|
||||||
@@ -7,7 +7,6 @@ Tests cover:
|
|||||||
|
|
||||||
Note: Uses import_isolation to break circular import chains.
|
Note: Uses import_isolation to break circular import chains.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -20,17 +19,15 @@ from typing import Generator
|
|||||||
|
|
||||||
class MockLifecycleControlScopeEnum:
|
class MockLifecycleControlScopeEnum:
|
||||||
"""Mock enum value for LifecycleControlScope with .value attribute."""
|
"""Mock enum value for LifecycleControlScope with .value attribute."""
|
||||||
|
|
||||||
def __init__(self, value: str):
|
def __init__(self, value: str):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'LifecycleControlScope.{self.value.upper()}'
|
return f"LifecycleControlScope.{self.value.upper()}"
|
||||||
|
|
||||||
|
|
||||||
class MockLifecycleControlScope:
|
class MockLifecycleControlScope:
|
||||||
"""Mock enum for LifecycleControlScope."""
|
"""Mock enum for LifecycleControlScope."""
|
||||||
|
|
||||||
APPLICATION = MockLifecycleControlScopeEnum('application')
|
APPLICATION = MockLifecycleControlScopeEnum('application')
|
||||||
PLATFORM = MockLifecycleControlScopeEnum('platform')
|
PLATFORM = MockLifecycleControlScopeEnum('platform')
|
||||||
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
|
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
|
||||||
@@ -43,17 +40,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
|||||||
# Mock modules that cause circular imports
|
# Mock modules that cause circular imports
|
||||||
mock_entities = MagicMock()
|
mock_entities = MagicMock()
|
||||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||||
|
|
||||||
mock_app = MagicMock()
|
mock_app = MagicMock()
|
||||||
|
|
||||||
mock_importutil = MagicMock()
|
mock_importutil = MagicMock()
|
||||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||||
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
||||||
|
|
||||||
mock_http_controller = MagicMock()
|
mock_http_controller = MagicMock()
|
||||||
|
|
||||||
mock_rag_mgr = MagicMock()
|
mock_rag_mgr = MagicMock()
|
||||||
|
|
||||||
mocks = {
|
mocks = {
|
||||||
'langbot.pkg.core.entities': mock_entities,
|
'langbot.pkg.core.entities': mock_entities,
|
||||||
'langbot.pkg.core.app': mock_app,
|
'langbot.pkg.core.app': mock_app,
|
||||||
@@ -61,26 +58,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
|||||||
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
|
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
|
||||||
'langbot.pkg.utils.importutil': mock_importutil,
|
'langbot.pkg.utils.importutil': mock_importutil,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save original state
|
# Save original state
|
||||||
saved = {}
|
saved = {}
|
||||||
for name in mocks:
|
for name in mocks:
|
||||||
if name in sys.modules:
|
if name in sys.modules:
|
||||||
saved[name] = sys.modules[name]
|
saved[name] = sys.modules[name]
|
||||||
|
|
||||||
# Clear taskmgr to force re-import
|
# Clear taskmgr to force re-import
|
||||||
taskmgr_name = 'langbot.pkg.core.taskmgr'
|
taskmgr_name = 'langbot.pkg.core.taskmgr'
|
||||||
if taskmgr_name in sys.modules:
|
if taskmgr_name in sys.modules:
|
||||||
saved[taskmgr_name] = sys.modules[taskmgr_name]
|
saved[taskmgr_name] = sys.modules[taskmgr_name]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply mocks
|
# Apply mocks
|
||||||
for name, module in mocks.items():
|
for name, module in mocks.items():
|
||||||
sys.modules[name] = module
|
sys.modules[name] = module
|
||||||
|
|
||||||
# Clear taskmgr
|
# Clear taskmgr
|
||||||
sys.modules.pop(taskmgr_name, None)
|
sys.modules.pop(taskmgr_name, None)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
# Restore
|
# Restore
|
||||||
@@ -89,7 +86,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
|||||||
sys.modules[name] = saved[name]
|
sys.modules[name] = saved[name]
|
||||||
else:
|
else:
|
||||||
sys.modules.pop(name, None)
|
sys.modules.pop(name, None)
|
||||||
|
|
||||||
if taskmgr_name in saved:
|
if taskmgr_name in saved:
|
||||||
sys.modules[taskmgr_name] = saved[taskmgr_name]
|
sys.modules[taskmgr_name] = saved[taskmgr_name]
|
||||||
else:
|
else:
|
||||||
@@ -100,7 +97,6 @@ def get_taskmgr_classes():
|
|||||||
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
|
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
|
||||||
with isolated_taskmgr_import():
|
with isolated_taskmgr_import():
|
||||||
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
|
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
|
||||||
|
|
||||||
return TaskContext, TaskWrapper, AsyncTaskManager
|
return TaskContext, TaskWrapper, AsyncTaskManager
|
||||||
|
|
||||||
|
|
||||||
@@ -198,10 +194,9 @@ class TestTaskContext:
|
|||||||
"""Test TaskContext.placeholder() returns singleton."""
|
"""Test TaskContext.placeholder() returns singleton."""
|
||||||
with isolated_taskmgr_import():
|
with isolated_taskmgr_import():
|
||||||
from langbot.pkg.core.taskmgr import TaskContext
|
from langbot.pkg.core.taskmgr import TaskContext
|
||||||
|
|
||||||
# Reset global placeholder
|
# Reset global placeholder
|
||||||
import langbot.pkg.core.taskmgr as taskmgr_module
|
import langbot.pkg.core.taskmgr as taskmgr_module
|
||||||
|
|
||||||
taskmgr_module.placeholder_context = None
|
taskmgr_module.placeholder_context = None
|
||||||
|
|
||||||
ctx1 = TaskContext.placeholder()
|
ctx1 = TaskContext.placeholder()
|
||||||
@@ -274,8 +269,7 @@ class TestTaskWrapper:
|
|||||||
return 'result'
|
return 'result'
|
||||||
|
|
||||||
wrapper = TaskWrapper(
|
wrapper = TaskWrapper(
|
||||||
mock_app,
|
mock_app, immediate_coro(),
|
||||||
immediate_coro(),
|
|
||||||
name='test_task',
|
name='test_task',
|
||||||
label='Test Task',
|
label='Test Task',
|
||||||
)
|
)
|
||||||
@@ -420,7 +414,7 @@ class TestAsyncTaskManager:
|
|||||||
async def test_cancel_by_scope(self):
|
async def test_cancel_by_scope(self):
|
||||||
"""Test cancel_by_scope cancels matching tasks."""
|
"""Test cancel_by_scope cancels matching tasks."""
|
||||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||||
|
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
manager = AsyncTaskManager(mock_app)
|
manager = AsyncTaskManager(mock_app)
|
||||||
|
|
||||||
@@ -428,10 +422,16 @@ class TestAsyncTaskManager:
|
|||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
# Create task with APPLICATION scope
|
# Create task with APPLICATION scope
|
||||||
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION])
|
w1 = manager.create_task(
|
||||||
|
long_coro(),
|
||||||
|
scopes=[MockLifecycleControlScope.APPLICATION]
|
||||||
|
)
|
||||||
|
|
||||||
# Create task with different scope
|
# Create task with different scope
|
||||||
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE])
|
w2 = manager.create_task(
|
||||||
|
long_coro(),
|
||||||
|
scopes=[MockLifecycleControlScope.PIPELINE]
|
||||||
|
)
|
||||||
|
|
||||||
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)
|
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)
|
||||||
|
|
||||||
|
|||||||
@@ -15,68 +15,68 @@ class TestI18nString:
|
|||||||
|
|
||||||
def test_create_with_english_only(self):
|
def test_create_with_english_only(self):
|
||||||
"""Create I18nString with only English."""
|
"""Create I18nString with only English."""
|
||||||
i18n = I18nString(en_US='Hello')
|
i18n = I18nString(en_US="Hello")
|
||||||
|
|
||||||
assert i18n.en_US == 'Hello'
|
assert i18n.en_US == "Hello"
|
||||||
assert i18n.zh_Hans is None
|
assert i18n.zh_Hans is None
|
||||||
|
|
||||||
def test_create_with_multiple_languages(self):
|
def test_create_with_multiple_languages(self):
|
||||||
"""Create I18nString with multiple languages."""
|
"""Create I18nString with multiple languages."""
|
||||||
i18n = I18nString(
|
i18n = I18nString(
|
||||||
en_US='Hello',
|
en_US="Hello",
|
||||||
zh_Hans='你好',
|
zh_Hans="你好",
|
||||||
zh_Hant='你好',
|
zh_Hant="你好",
|
||||||
ja_JP='こんにちは',
|
ja_JP="こんにちは",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert i18n.en_US == 'Hello'
|
assert i18n.en_US == "Hello"
|
||||||
assert i18n.zh_Hans == '你好'
|
assert i18n.zh_Hans == "你好"
|
||||||
assert i18n.zh_Hant == '你好'
|
assert i18n.zh_Hant == "你好"
|
||||||
assert i18n.ja_JP == 'こんにちは'
|
assert i18n.ja_JP == "こんにちは"
|
||||||
|
|
||||||
def test_to_dict_with_english_only(self):
|
def test_to_dict_with_english_only(self):
|
||||||
"""to_dict returns only non-None fields."""
|
"""to_dict returns only non-None fields."""
|
||||||
i18n = I18nString(en_US='Hello')
|
i18n = I18nString(en_US="Hello")
|
||||||
|
|
||||||
result = i18n.to_dict()
|
result = i18n.to_dict()
|
||||||
|
|
||||||
assert result == {'en_US': 'Hello'}
|
assert result == {"en_US": "Hello"}
|
||||||
|
|
||||||
def test_to_dict_with_multiple_languages(self):
|
def test_to_dict_with_multiple_languages(self):
|
||||||
"""to_dict returns all non-None fields."""
|
"""to_dict returns all non-None fields."""
|
||||||
i18n = I18nString(
|
i18n = I18nString(
|
||||||
en_US='Hello',
|
en_US="Hello",
|
||||||
zh_Hans='你好',
|
zh_Hans="你好",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = i18n.to_dict()
|
result = i18n.to_dict()
|
||||||
|
|
||||||
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'}
|
assert result == {"en_US": "Hello", "zh_Hans": "你好"}
|
||||||
|
|
||||||
def test_to_dict_excludes_none(self):
|
def test_to_dict_excludes_none(self):
|
||||||
"""to_dict excludes None values."""
|
"""to_dict excludes None values."""
|
||||||
i18n = I18nString(
|
i18n = I18nString(
|
||||||
en_US='Hello',
|
en_US="Hello",
|
||||||
zh_Hans=None,
|
zh_Hans=None,
|
||||||
ja_JP='こんにちは',
|
ja_JP="こんにちは",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = i18n.to_dict()
|
result = i18n.to_dict()
|
||||||
|
|
||||||
assert 'zh_Hans' not in result
|
assert "zh_Hans" not in result
|
||||||
assert 'en_US' in result
|
assert "en_US" in result
|
||||||
assert 'ja_JP' in result
|
assert "ja_JP" in result
|
||||||
|
|
||||||
def test_to_dict_all_languages(self):
|
def test_to_dict_all_languages(self):
|
||||||
"""to_dict with all supported languages."""
|
"""to_dict with all supported languages."""
|
||||||
i18n = I18nString(
|
i18n = I18nString(
|
||||||
en_US='Hello',
|
en_US="Hello",
|
||||||
zh_Hans='你好',
|
zh_Hans="你好",
|
||||||
zh_Hant='你好',
|
zh_Hant="你好",
|
||||||
ja_JP='こんにちは',
|
ja_JP="こんにちは",
|
||||||
th_TH='สวัสดี',
|
th_TH="สวัสดี",
|
||||||
vi_VN='Xin chào',
|
vi_VN="Xin chào",
|
||||||
es_ES='Hola',
|
es_ES="Hola",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = i18n.to_dict()
|
result = i18n.to_dict()
|
||||||
@@ -92,30 +92,30 @@ class TestMetadata:
|
|||||||
from langbot.pkg.discover.engine import I18nString
|
from langbot.pkg.discover.engine import I18nString
|
||||||
|
|
||||||
metadata = Metadata(
|
metadata = Metadata(
|
||||||
name='test-component',
|
name="test-component",
|
||||||
label=I18nString(en_US='Test Component'),
|
label=I18nString(en_US="Test Component"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metadata.name == 'test-component'
|
assert metadata.name == "test-component"
|
||||||
assert metadata.label.en_US == 'Test Component'
|
assert metadata.label.en_US == "Test Component"
|
||||||
|
|
||||||
def test_create_with_all_fields(self):
|
def test_create_with_all_fields(self):
|
||||||
"""Create Metadata with all optional fields."""
|
"""Create Metadata with all optional fields."""
|
||||||
from langbot.pkg.discover.engine import I18nString
|
from langbot.pkg.discover.engine import I18nString
|
||||||
|
|
||||||
metadata = Metadata(
|
metadata = Metadata(
|
||||||
name='test-component',
|
name="test-component",
|
||||||
label=I18nString(en_US='Test'),
|
label=I18nString(en_US="Test"),
|
||||||
description=I18nString(en_US='A test component'),
|
description=I18nString(en_US="A test component"),
|
||||||
version='1.0.0',
|
version="1.0.0",
|
||||||
icon='test-icon',
|
icon="test-icon",
|
||||||
author='Test Author',
|
author="Test Author",
|
||||||
repository='https://github.com/test/repo',
|
repository="https://github.com/test/repo",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metadata.version == '1.0.0'
|
assert metadata.version == "1.0.0"
|
||||||
assert metadata.icon == 'test-icon'
|
assert metadata.icon == "test-icon"
|
||||||
assert metadata.author == 'Test Author'
|
assert metadata.author == "Test Author"
|
||||||
|
|
||||||
|
|
||||||
class TestComponentManifest:
|
class TestComponentManifest:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Tests cover:
|
|||||||
|
|
||||||
Note: Uses import isolation to break circular import chains.
|
Note: Uses import isolation to break circular import chains.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
@@ -87,7 +86,6 @@ def get_database_module():
|
|||||||
"""Get database module with import isolation."""
|
"""Get database module with import isolation."""
|
||||||
with isolated_database_import():
|
with isolated_database_import():
|
||||||
from langbot.pkg.persistence import database
|
from langbot.pkg.persistence import database
|
||||||
|
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
||||||
@@ -200,4 +198,4 @@ class TestManagerClassDecorator:
|
|||||||
# Create instance to test method (with mock app)
|
# Create instance to test method (with mock app)
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
instance = ManagerWithMethods(mock_app)
|
instance = ManagerWithMethods(mock_app)
|
||||||
assert instance.custom_method() == 'test_value'
|
assert instance.custom_method() == 'test_value'
|
||||||
@@ -4,7 +4,6 @@ Tests cover:
|
|||||||
- execute_async() with mock database
|
- execute_async() with mock database
|
||||||
- get_db_engine() with mock database manager
|
- get_db_engine() with mock database manager
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -86,7 +85,7 @@ class TestExecuteAsync:
|
|||||||
mock_db.get_engine = Mock(return_value=mock_engine)
|
mock_db.get_engine = Mock(return_value=mock_engine)
|
||||||
mgr.db = mock_db
|
mgr.db = mock_db
|
||||||
|
|
||||||
result = await mgr.execute_async(sqlalchemy.text('SELECT 1'))
|
result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
|
||||||
|
|
||||||
# Verify result is the same object returned by execute
|
# Verify result is the same object returned by execute
|
||||||
assert result is mock_result
|
assert result is mock_result
|
||||||
@@ -153,4 +152,4 @@ class TestSerializeModelEdgeCases:
|
|||||||
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
|
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
|
||||||
|
|
||||||
# Result should be empty dict when all columns masked
|
# Result should be empty dict when all columns masked
|
||||||
assert result == {}
|
assert result == {}
|
||||||
@@ -5,7 +5,6 @@ Tests cover:
|
|||||||
- datetime conversion to isoformat
|
- datetime conversion to isoformat
|
||||||
- masked_columns exclusion
|
- masked_columns exclusion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class TestPendingMessage:
|
|||||||
"""PendingMessage should be created with correct fields."""
|
"""PendingMessage should be created with correct fields."""
|
||||||
aggregator = get_aggregator_module()
|
aggregator = get_aggregator_module()
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ class TestSessionBuffer:
|
|||||||
"""SessionBuffer should accept initial messages."""
|
"""SessionBuffer should accept initial messages."""
|
||||||
aggregator = get_aggregator_module()
|
aggregator = get_aggregator_module()
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
|
|||||||
|
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
|
|||||||
|
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain1 = text_chain('hello')
|
chain1 = text_chain("hello")
|
||||||
chain2 = text_chain('world')
|
chain2 = text_chain("world")
|
||||||
event = friend_message_event(chain1)
|
event = friend_message_event(chain1)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
|
|||||||
|
|
||||||
# Should contain both messages with separator
|
# Should contain both messages with separator
|
||||||
merged_str = str(merged.message_chain)
|
merged_str = str(merged.message_chain)
|
||||||
assert 'hello' in merged_str
|
assert "hello" in merged_str
|
||||||
assert 'world' in merged_str
|
assert "world" in merged_str
|
||||||
|
|
||||||
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
|
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
|
||||||
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
|
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
|
||||||
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain1 = text_chain('first')
|
chain1 = text_chain("first")
|
||||||
chain2 = text_chain('second')
|
chain2 = text_chain("second")
|
||||||
event = friend_message_event(chain1)
|
event = friend_message_event(chain1)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
|
|||||||
app = make_aggregator_app()
|
app = make_aggregator_app()
|
||||||
agg = aggregator.MessageAggregator(app)
|
agg = aggregator.MessageAggregator(app)
|
||||||
|
|
||||||
chain = text_chain('hello')
|
chain = text_chain("hello")
|
||||||
event = friend_message_event(chain)
|
event = friend_message_event(chain)
|
||||||
adapter = mock_adapter()
|
adapter = mock_adapter()
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from tests.factories import FakeApp
|
|||||||
|
|
||||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def mock_circular_import_chain():
|
def mock_circular_import_chain():
|
||||||
"""
|
"""
|
||||||
@@ -37,11 +36,9 @@ def mock_circular_import_chain():
|
|||||||
# Create a default runner that yields a simple response
|
# Create a default runner that yields a simple response
|
||||||
class DefaultRunner:
|
class DefaultRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
yield Message(role='assistant', content='fake response')
|
yield Message(role='assistant', content='fake response')
|
||||||
|
|
||||||
@@ -73,12 +70,9 @@ def mock_event_ctx():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def set_runner():
|
def set_runner():
|
||||||
"""Factory fixture to set a custom runner for tests."""
|
"""Factory fixture to set a custom runner for tests."""
|
||||||
|
|
||||||
def _set_runner(runner_class):
|
def _set_runner(runner_class):
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
|
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
|
||||||
|
|
||||||
return _set_runner
|
return _set_runner
|
||||||
|
|
||||||
|
|
||||||
@@ -93,7 +87,6 @@ def get_chat_handler():
|
|||||||
global _chat_handler_module
|
global _chat_handler_module
|
||||||
if _chat_handler_module is None:
|
if _chat_handler_module is None:
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
|
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
|
||||||
return _chat_handler_module
|
return _chat_handler_module
|
||||||
|
|
||||||
@@ -103,14 +96,12 @@ def get_entities():
|
|||||||
global _entities_module
|
global _entities_module
|
||||||
if _entities_module is None:
|
if _entities_module is None:
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||||
return _entities_module
|
return _entities_module
|
||||||
|
|
||||||
|
|
||||||
# ============== REAL ChatMessageHandler Tests ==============
|
# ============== REAL ChatMessageHandler Tests ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||||
class TestChatMessageHandlerReal:
|
class TestChatMessageHandlerReal:
|
||||||
"""Tests for real ChatMessageHandler class."""
|
"""Tests for real ChatMessageHandler class."""
|
||||||
@@ -197,11 +188,9 @@ class TestChatMessageHandlerReal:
|
|||||||
|
|
||||||
class QuickRunner:
|
class QuickRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
yield Message(role='assistant', content='ok')
|
yield Message(role='assistant', content='ok')
|
||||||
|
|
||||||
@@ -233,11 +222,9 @@ class TestChatMessageHandlerReal:
|
|||||||
|
|
||||||
class SingleRunner:
|
class SingleRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
yield Message(role='assistant', content='response')
|
yield Message(role='assistant', content='response')
|
||||||
|
|
||||||
@@ -275,11 +262,9 @@ class TestChatHandlerStreaming:
|
|||||||
|
|
||||||
class StreamRunner:
|
class StreamRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
yield MessageChunk(role='assistant', content='Hello', is_final=False)
|
yield MessageChunk(role='assistant', content='Hello', is_final=False)
|
||||||
yield MessageChunk(role='assistant', content=' World', is_final=True)
|
yield MessageChunk(role='assistant', content=' World', is_final=True)
|
||||||
@@ -318,19 +303,14 @@ class TestChatHandlerExceptions:
|
|||||||
|
|
||||||
query.pipeline_config = {
|
query.pipeline_config = {
|
||||||
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
|
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
|
||||||
'ai': {
|
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||||
'runner': {'runner': 'local-agent'},
|
|
||||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class FailingRunner:
|
class FailingRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
raise ValueError('API error')
|
raise ValueError('API error')
|
||||||
yield
|
yield
|
||||||
@@ -366,19 +346,14 @@ class TestChatHandlerExceptions:
|
|||||||
|
|
||||||
query.pipeline_config = {
|
query.pipeline_config = {
|
||||||
'output': {'misc': {'exception-handling': 'show-error'}},
|
'output': {'misc': {'exception-handling': 'show-error'}},
|
||||||
'ai': {
|
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||||
'runner': {'runner': 'local-agent'},
|
|
||||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ErrorRunner:
|
class ErrorRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
raise ValueError('Custom error')
|
raise ValueError('Custom error')
|
||||||
yield
|
yield
|
||||||
@@ -411,19 +386,14 @@ class TestChatHandlerExceptions:
|
|||||||
|
|
||||||
query.pipeline_config = {
|
query.pipeline_config = {
|
||||||
'output': {'misc': {'exception-handling': 'hide'}},
|
'output': {'misc': {'exception-handling': 'hide'}},
|
||||||
'ai': {
|
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||||
'runner': {'runner': 'local-agent'},
|
|
||||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class HideErrorRunner:
|
class HideErrorRunner:
|
||||||
name = 'local-agent'
|
name = 'local-agent'
|
||||||
|
|
||||||
def __init__(self, app, config):
|
def __init__(self, app, config):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def run(self, query):
|
async def run(self, query):
|
||||||
raise RuntimeError('hidden')
|
raise RuntimeError('hidden')
|
||||||
yield
|
yield
|
||||||
@@ -463,4 +433,4 @@ class TestChatHandlerHelper:
|
|||||||
chat = get_chat_handler()
|
chat = get_chat_handler()
|
||||||
handler = chat.ChatMessageHandler(fake_app)
|
handler = chat.ChatMessageHandler(fake_app)
|
||||||
result = handler.cut_str('first line\nsecond line')
|
result = handler.cut_str('first line\nsecond line')
|
||||||
assert '...' in result
|
assert '...' in result
|
||||||
@@ -67,11 +67,7 @@ def make_pipeline_config(**overrides):
|
|||||||
for key, value in overrides.items():
|
for key, value in overrides.items():
|
||||||
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
|
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
|
||||||
for sub_key, sub_value in value.items():
|
for sub_key, sub_value in value.items():
|
||||||
if (
|
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
|
||||||
sub_key in base_config[key]
|
|
||||||
and isinstance(base_config[key][sub_key], dict)
|
|
||||||
and isinstance(sub_value, dict)
|
|
||||||
):
|
|
||||||
base_config[key][sub_key].update(sub_value)
|
base_config[key][sub_key].update(sub_value)
|
||||||
else:
|
else:
|
||||||
base_config[key][sub_key] = sub_value
|
base_config[key][sub_key] = sub_value
|
||||||
@@ -145,7 +141,7 @@ class TestPreContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello world')
|
query = text_query("hello world")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -167,7 +163,7 @@ class TestPreContentFilter:
|
|||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
# Empty message chain
|
# Empty message chain
|
||||||
query = text_query('')
|
query = text_query("")
|
||||||
query.message_chain = platform_message.MessageChain([])
|
query.message_chain = platform_message.MessageChain([])
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
@@ -189,7 +185,7 @@ class TestPreContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query(' ') # Only whitespace
|
query = text_query(" ") # Only whitespace
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -238,7 +234,7 @@ class TestPreContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello world')
|
query = text_query("hello world")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -270,7 +266,7 @@ class TestContentIgnoreFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('/help me')
|
query = text_query("/help me")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -298,7 +294,7 @@ class TestContentIgnoreFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('http://example.com')
|
query = text_query("http://example.com")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -326,7 +322,7 @@ class TestContentIgnoreFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('normal message')
|
query = text_query("normal message")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -347,7 +343,7 @@ class TestContentIgnoreFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('/help me')
|
query = text_query("/help me")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
@@ -372,10 +368,12 @@ class TestPostContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
# Add a response message
|
# Add a response message
|
||||||
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
|
query.resp_messages = [
|
||||||
|
provider_message.Message(role='assistant', content='Hello back!')
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'PostContentFilterStage')
|
result = await stage.process(query, 'PostContentFilterStage')
|
||||||
|
|
||||||
@@ -400,9 +398,11 @@ class TestPostContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
|
query.resp_messages = [
|
||||||
|
provider_message.Message(role='assistant', content='Response')
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'PostContentFilterStage')
|
result = await stage.process(query, 'PostContentFilterStage')
|
||||||
|
|
||||||
@@ -422,7 +422,7 @@ class TestPostContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
# Non-string content - use model_construct to bypass validation
|
# Non-string content - use model_construct to bypass validation
|
||||||
# The actual content type could be a list of ContentElement objects
|
# The actual content type could be a list of ContentElement objects
|
||||||
@@ -450,9 +450,11 @@ class TestPostContentFilter:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_messages = [provider_message.Message(role='assistant', content='')]
|
query.resp_messages = [
|
||||||
|
provider_message.Message(role='assistant', content='')
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'PostContentFilterStage')
|
result = await stage.process(query, 'PostContentFilterStage')
|
||||||
|
|
||||||
@@ -474,7 +476,7 @@ class TestContentFilterStageInvalidName:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
|
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
|
||||||
@@ -504,7 +506,7 @@ class TestContentIgnoreFilterDirect:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('normal message without prefix')
|
query = text_query("normal message without prefix")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
result = await stage.process(query, 'PreContentFilterStage')
|
result = await stage.process(query, 'PreContentFilterStage')
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from tests.factories import FakeApp, command_query
|
|||||||
|
|
||||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def mock_circular_import_chain():
|
def mock_circular_import_chain():
|
||||||
"""
|
"""
|
||||||
@@ -57,7 +56,6 @@ def mock_event_ctx():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_execute_factory():
|
def mock_execute_factory():
|
||||||
"""Factory fixture to create mock cmd_mgr.execute generators."""
|
"""Factory fixture to create mock cmd_mgr.execute generators."""
|
||||||
|
|
||||||
def _create_execute(
|
def _create_execute(
|
||||||
text: str | None = 'ok',
|
text: str | None = 'ok',
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
@@ -73,9 +71,7 @@ def mock_execute_factory():
|
|||||||
ret.image_base64 = image_base64
|
ret.image_base64 = image_base64
|
||||||
ret.file_url = file_url
|
ret.file_url = file_url
|
||||||
yield ret
|
yield ret
|
||||||
|
|
||||||
return mock_execute
|
return mock_execute
|
||||||
|
|
||||||
return _create_execute
|
return _create_execute
|
||||||
|
|
||||||
|
|
||||||
@@ -90,7 +86,6 @@ def get_command_handler():
|
|||||||
global _command_handler_module
|
global _command_handler_module
|
||||||
if _command_handler_module is None:
|
if _command_handler_module is None:
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
|
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
|
||||||
return _command_handler_module
|
return _command_handler_module
|
||||||
|
|
||||||
@@ -100,14 +95,12 @@ def get_entities():
|
|||||||
global _entities_module
|
global _entities_module
|
||||||
if _entities_module is None:
|
if _entities_module is None:
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||||
return _entities_module
|
return _entities_module
|
||||||
|
|
||||||
|
|
||||||
# ============== REAL CommandHandler Tests ==============
|
# ============== REAL CommandHandler Tests ==============
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||||
class TestCommandHandlerReal:
|
class TestCommandHandlerReal:
|
||||||
"""Tests for real CommandHandler class."""
|
"""Tests for real CommandHandler class."""
|
||||||
@@ -134,7 +127,6 @@ class TestCommandHandlerReal:
|
|||||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
executed_commands = []
|
executed_commands = []
|
||||||
|
|
||||||
async def track_execute(command_text, full_command_text, query, session):
|
async def track_execute(command_text, full_command_text, query, session):
|
||||||
executed_commands.append(command_text)
|
executed_commands.append(command_text)
|
||||||
ret = Mock()
|
ret = Mock()
|
||||||
@@ -342,7 +334,8 @@ class TestCommandHandlerReal:
|
|||||||
command = get_command_handler()
|
command = get_command_handler()
|
||||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
fake_app.cmd_mgr.execute = mock_execute_factory(
|
fake_app.cmd_mgr.execute = mock_execute_factory(
|
||||||
text='Here is the image:', image_url='https://example.com/image.png'
|
text='Here is the image:',
|
||||||
|
image_url='https://example.com/image.png'
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = command.CommandHandler(fake_app)
|
handler = command.CommandHandler(fake_app)
|
||||||
@@ -400,4 +393,4 @@ class TestCommandHandlerHelper:
|
|||||||
command = get_command_handler()
|
command = get_command_handler()
|
||||||
handler = command.CommandHandler(fake_app)
|
handler = command.CommandHandler(fake_app)
|
||||||
result = handler.cut_str('first line\nsecond line')
|
result = handler.cut_str('first line\nsecond line')
|
||||||
assert '...' in result
|
assert '...' in result
|
||||||
@@ -126,9 +126,11 @@ class TestLongTextProcessStageProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
|
query.resp_message_chain = [
|
||||||
|
platform_message.MessageChain([platform_message.Plain(text="very long response")])
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'LongTextProcessStage')
|
result = await stage.process(query, 'LongTextProcessStage')
|
||||||
|
|
||||||
@@ -149,9 +151,11 @@ class TestLongTextProcessStageProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
|
query.resp_message_chain = [
|
||||||
|
platform_message.MessageChain([platform_message.Plain(text="short response")])
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'LongTextProcessStage')
|
result = await stage.process(query, 'LongTextProcessStage')
|
||||||
|
|
||||||
@@ -175,13 +179,14 @@ class TestLongTextProcessStageProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
# Non-Plain component (Image)
|
# Non-Plain component (Image)
|
||||||
query.resp_message_chain = [
|
query.resp_message_chain = [
|
||||||
platform_message.MessageChain(
|
platform_message.MessageChain([
|
||||||
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
|
platform_message.Plain(text="short"),
|
||||||
)
|
platform_message.Image(url="https://example.com/img.png")
|
||||||
|
])
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'LongTextProcessStage')
|
result = await stage.process(query, 'LongTextProcessStage')
|
||||||
@@ -208,7 +213,7 @@ class TestLongTextProcessStageProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -227,7 +232,7 @@ class TestLongTextProcessStageProcess:
|
|||||||
stage = longtext.LongTextProcessStage(app)
|
stage = longtext.LongTextProcessStage(app)
|
||||||
stage.strategy_impl = AsyncMock()
|
stage.strategy_impl = AsyncMock()
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
|
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -237,7 +242,6 @@ class TestLongTextProcessStageProcess:
|
|||||||
assert result.new_query is query
|
assert result.new_query is query
|
||||||
stage.strategy_impl.process.assert_not_called()
|
stage.strategy_impl.process.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
class TestForwardStrategy:
|
class TestForwardStrategy:
|
||||||
"""Tests for ForwardComponentStrategy."""
|
"""Tests for ForwardComponentStrategy."""
|
||||||
|
|
||||||
@@ -256,7 +260,7 @@ class TestForwardStrategy:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
# Create a mock adapter with bot_account_id
|
# Create a mock adapter with bot_account_id
|
||||||
mock_adapter = Mock()
|
mock_adapter = Mock()
|
||||||
@@ -264,8 +268,10 @@ class TestForwardStrategy:
|
|||||||
query.adapter = mock_adapter
|
query.adapter = mock_adapter
|
||||||
|
|
||||||
# Long text exceeding threshold
|
# Long text exceeding threshold
|
||||||
long_text = 'This is a very long response that exceeds the threshold'
|
long_text = "This is a very long response that exceeds the threshold"
|
||||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
|
query.resp_message_chain = [
|
||||||
|
platform_message.MessageChain([platform_message.Plain(text=long_text)])
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'LongTextProcessStage')
|
result = await stage.process(query, 'LongTextProcessStage')
|
||||||
|
|
||||||
@@ -291,13 +297,13 @@ class TestForwardStrategy:
|
|||||||
|
|
||||||
await strat.initialize()
|
await strat.initialize()
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = make_longtext_config()
|
query.pipeline_config = make_longtext_config()
|
||||||
mock_adapter = Mock()
|
mock_adapter = Mock()
|
||||||
mock_adapter.bot_account_id = '12345'
|
mock_adapter.bot_account_id = '12345'
|
||||||
query.adapter = mock_adapter
|
query.adapter = mock_adapter
|
||||||
|
|
||||||
components = await strat.process('test message', query)
|
components = await strat.process("test message", query)
|
||||||
|
|
||||||
assert len(components) == 1
|
assert len(components) == 1
|
||||||
assert isinstance(components[0], platform_message.Forward)
|
assert isinstance(components[0], platform_message.Forward)
|
||||||
@@ -320,12 +326,14 @@ class TestLongTextThreshold:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
|
|
||||||
# Text below threshold
|
# Text below threshold
|
||||||
short_text = 'x' * (threshold - 1)
|
short_text = "x" * (threshold - 1)
|
||||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
|
query.resp_message_chain = [
|
||||||
|
platform_message.MessageChain([platform_message.Plain(text=short_text)])
|
||||||
|
]
|
||||||
|
|
||||||
result = await stage.process(query, 'LongTextProcessStage')
|
result = await stage.process(query, 'LongTextProcessStage')
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
|
|||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
# Create query with 3 messages (within limit)
|
# Create query with 3 messages (within limit)
|
||||||
query = text_query('current message')
|
query = text_query("current message")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='message 1'),
|
provider_message.Message(role='user', content='message 1'),
|
||||||
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
|
|||||||
|
|
||||||
# Create query with many messages exceeding limit
|
# Create query with many messages exceeding limit
|
||||||
# 7 messages = 3 full rounds + 1 current user
|
# 7 messages = 3 full rounds + 1 current user
|
||||||
query = text_query('current message')
|
query = text_query("current message")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='message 1'),
|
provider_message.Message(role='user', content='message 1'),
|
||||||
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = []
|
query.messages = []
|
||||||
|
|
||||||
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='hello'),
|
provider_message.Message(role='user', content='hello'),
|
||||||
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('current')
|
query = text_query("current")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='user1'),
|
provider_message.Message(role='user', content='user1'),
|
||||||
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('current')
|
query = text_query("current")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='old1'),
|
provider_message.Message(role='user', content='old1'),
|
||||||
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
|
|||||||
trun = trun_cls(app)
|
trun = trun_cls(app)
|
||||||
break
|
break
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = make_truncate_config(max_round=3)
|
query.pipeline_config = make_truncate_config(max_round=3)
|
||||||
query.messages = [
|
query.messages = [
|
||||||
provider_message.Message(role='user', content='m1'),
|
provider_message.Message(role='user', content='m1'),
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = text_query('hello world')
|
query = text_query("hello world")
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = text_query('test message')
|
query = text_query("test message")
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
|
|
||||||
@@ -194,16 +194,13 @@ class TestPreProcessorImageSegment:
|
|||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
# Image query with base64
|
# Image query with base64
|
||||||
query = image_query(text='look at this', url=None)
|
query = image_query(text="look at this", url=None)
|
||||||
# Set base64 on the image component
|
# Set base64 on the image component
|
||||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
chain = platform_message.MessageChain([
|
||||||
chain = platform_message.MessageChain(
|
platform_message.Plain(text="look at this"),
|
||||||
[
|
platform_message.Image(base64="data:image/png;base64,abc123"),
|
||||||
platform_message.Plain(text='look at this'),
|
])
|
||||||
platform_message.Image(base64='data:image/png;base64,abc123'),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
query.message_chain = chain
|
query.message_chain = chain
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
@@ -241,7 +238,7 @@ class TestPreProcessorImageSegment:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = image_query(text='describe this')
|
query = image_query(text="describe this")
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
|
|
||||||
@@ -279,7 +276,7 @@ class TestPreProcessorModelSelection:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
# Set pipeline config with primary model
|
# Set pipeline config with primary model
|
||||||
query.pipeline_config = {
|
query.pipeline_config = {
|
||||||
@@ -338,7 +335,7 @@ class TestPreProcessorModelSelection:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
|
|
||||||
query.pipeline_config = {
|
query.pipeline_config = {
|
||||||
'ai': {
|
'ai': {
|
||||||
@@ -387,7 +384,7 @@ class TestPreProcessorVariables:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = text_query('hello', sender_id=67890)
|
query = text_query("hello", sender_id=67890)
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
|
|
||||||
@@ -424,7 +421,7 @@ class TestPreProcessorVariables:
|
|||||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||||
|
|
||||||
stage = preproc.PreProcessor(app)
|
stage = preproc.PreProcessor(app)
|
||||||
query = group_text_query('hello', group_id=99999)
|
query = group_text_query("hello", group_id=99999)
|
||||||
|
|
||||||
result = await stage.process(query, 'PreProcessor')
|
result = await stage.process(query, 'PreProcessor')
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
|
|||||||
'safety': {
|
'safety': {
|
||||||
'rate-limit': {
|
'rate-limit': {
|
||||||
'window-length': 60, # 60 seconds window
|
'window-length': 60, # 60 seconds window
|
||||||
'limitation': 10, # 10 requests per window
|
'limitation': 10, # 10 requests per window
|
||||||
'strategy': 'drop',
|
'strategy': 'drop',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,9 +75,11 @@ class TestFixedWindowAlgo:
|
|||||||
# Make requests within limit
|
# Make requests within limit
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
result = await algo.require_access(
|
result = await algo.require_access(
|
||||||
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
)
|
)
|
||||||
assert result is True, f'Request {i + 1} should be allowed'
|
assert result is True, f"Request {i+1} should be allowed"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
|
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||||
@@ -89,12 +91,20 @@ class TestFixedWindowAlgo:
|
|||||||
|
|
||||||
# Exhaust the limit
|
# Exhaust the limit
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
await algo.require_access(
|
||||||
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
|
|
||||||
# Next request should be denied
|
# Next request should be denied
|
||||||
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
result = await algo.require_access(
|
||||||
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
|
|
||||||
assert result is False, 'Request exceeding limit should be denied'
|
assert result is False, "Request exceeding limit should be denied"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
|
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||||
@@ -106,14 +116,20 @@ class TestFixedWindowAlgo:
|
|||||||
|
|
||||||
# Exhaust limit for session 1
|
# Exhaust limit for session 1
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
|
await algo.require_access(
|
||||||
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'session1'
|
||||||
|
)
|
||||||
|
|
||||||
# Session 2 should still have its own limit
|
# Session 2 should still have its own limit
|
||||||
result = await algo.require_access(
|
result = await algo.require_access(
|
||||||
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'session2'
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is True, 'Different session should have independent limit'
|
assert result is True, "Different session should have independent limit"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
|
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
|
||||||
@@ -134,11 +150,19 @@ class TestFixedWindowAlgo:
|
|||||||
await algo.initialize()
|
await algo.initialize()
|
||||||
|
|
||||||
# First request allowed
|
# First request allowed
|
||||||
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
|
result1 = await algo.require_access(
|
||||||
|
sample_query,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
assert result1 is True
|
assert result1 is True
|
||||||
|
|
||||||
# Second request denied
|
# Second request denied
|
||||||
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
|
result2 = await algo.require_access(
|
||||||
|
sample_query,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
assert result2 is False
|
assert result2 is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -150,7 +174,11 @@ class TestFixedWindowAlgo:
|
|||||||
await algo.initialize()
|
await algo.initialize()
|
||||||
|
|
||||||
# First request creates container
|
# First request creates container
|
||||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
await algo.require_access(
|
||||||
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
|
|
||||||
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
|
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
|
||||||
expected_key = 'LauncherTypes.PERSON_12345'
|
expected_key = 'LauncherTypes.PERSON_12345'
|
||||||
@@ -202,7 +230,7 @@ class TestFixedWindowAlgo:
|
|||||||
|
|
||||||
# New request should be allowed (new window)
|
# New request should be allowed (new window)
|
||||||
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
|
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
|
||||||
assert result is True, 'New window should allow new requests'
|
assert result is True, "New window should allow new requests"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
|
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
|
||||||
@@ -228,21 +256,29 @@ class TestFixedWindowAlgo:
|
|||||||
|
|
||||||
# First request allowed
|
# First request allowed
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
result1 = await algo.require_access(
|
||||||
|
sample_query,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'wait_test'
|
||||||
|
)
|
||||||
assert result1 is True
|
assert result1 is True
|
||||||
|
|
||||||
# Exhaust limit
|
# Exhaust limit
|
||||||
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
||||||
|
|
||||||
# Third request should wait and then succeed
|
# Third request should wait and then succeed
|
||||||
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
result3 = await algo.require_access(
|
||||||
|
sample_query,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'wait_test'
|
||||||
|
)
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
assert result3 is True, 'After wait, request should succeed'
|
assert result3 is True, "After wait, request should succeed"
|
||||||
# Should have waited approximately until next window
|
# Should have waited approximately until next window
|
||||||
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
|
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
|
||||||
# Note: This is a timing-sensitive test, so we use a generous tolerance
|
# Note: This is a timing-sensitive test, so we use a generous tolerance
|
||||||
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
|
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
|
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||||
@@ -253,7 +289,11 @@ class TestFixedWindowAlgo:
|
|||||||
await algo.initialize()
|
await algo.initialize()
|
||||||
|
|
||||||
# release_access is empty in current implementation
|
# release_access is empty in current implementation
|
||||||
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
await algo.release_access(
|
||||||
|
sample_query_with_rate_limit,
|
||||||
|
provider_session.LauncherTypes.PERSON,
|
||||||
|
'12345'
|
||||||
|
)
|
||||||
|
|
||||||
# Should not raise or change state
|
# Should not raise or change state
|
||||||
assert 'person_12345' not in algo.containers
|
assert 'person_12345' not in algo.containers
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def make_session():
|
|||||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||||
launcher_id=12345,
|
launcher_id=12345,
|
||||||
sender_id=12345,
|
sender_id=12345,
|
||||||
use_prompt_name='default',
|
use_prompt_name="default",
|
||||||
using_conversation=None,
|
using_conversation=None,
|
||||||
conversations=[],
|
conversations=[],
|
||||||
)
|
)
|
||||||
@@ -93,9 +93,11 @@ class TestResponseWrapperMessageChain:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
|
query.resp_messages = [
|
||||||
|
platform_message.MessageChain([platform_message.Plain(text="response")])
|
||||||
|
]
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@@ -123,7 +125,7 @@ class TestResponseWrapperCommand:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -131,7 +133,7 @@ class TestResponseWrapperCommand:
|
|||||||
command_resp = Mock()
|
command_resp = Mock()
|
||||||
command_resp.role = 'command'
|
command_resp.role = 'command'
|
||||||
command_resp.get_content_platform_message_chain = Mock(
|
command_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [command_resp]
|
query.resp_messages = [command_resp]
|
||||||
|
|
||||||
@@ -161,7 +163,7 @@ class TestResponseWrapperPlugin:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -169,7 +171,7 @@ class TestResponseWrapperPlugin:
|
|||||||
plugin_resp = Mock()
|
plugin_resp = Mock()
|
||||||
plugin_resp.role = 'plugin'
|
plugin_resp.role = 'plugin'
|
||||||
plugin_resp.get_content_platform_message_chain = Mock(
|
plugin_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [plugin_resp]
|
query.resp_messages = [plugin_resp]
|
||||||
|
|
||||||
@@ -209,17 +211,17 @@ class TestResponseWrapperAssistant:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
# Create assistant response with content
|
# Create assistant response with content
|
||||||
assistant_resp = Mock()
|
assistant_resp = Mock()
|
||||||
assistant_resp.role = 'assistant'
|
assistant_resp.role = 'assistant'
|
||||||
assistant_resp.content = 'Hello back!'
|
assistant_resp.content = "Hello back!"
|
||||||
assistant_resp.tool_calls = None
|
assistant_resp.tool_calls = None
|
||||||
assistant_resp.get_content_platform_message_chain = Mock(
|
assistant_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [assistant_resp]
|
query.resp_messages = [assistant_resp]
|
||||||
|
|
||||||
@@ -245,7 +247,7 @@ class TestResponseWrapperAssistant:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -290,7 +292,7 @@ class TestResponseWrapperAssistant:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
@@ -301,10 +303,10 @@ class TestResponseWrapperAssistant:
|
|||||||
|
|
||||||
assistant_resp = Mock()
|
assistant_resp = Mock()
|
||||||
assistant_resp.role = 'assistant'
|
assistant_resp.role = 'assistant'
|
||||||
assistant_resp.content = 'Processing...'
|
assistant_resp.content = "Processing..."
|
||||||
assistant_resp.tool_calls = [mock_tool_call]
|
assistant_resp.tool_calls = [mock_tool_call]
|
||||||
assistant_resp.get_content_platform_message_chain = Mock(
|
assistant_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [assistant_resp]
|
query.resp_messages = [assistant_resp]
|
||||||
|
|
||||||
@@ -344,17 +346,17 @@ class TestResponseWrapperInterrupt:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
# Create assistant response with content
|
# Create assistant response with content
|
||||||
assistant_resp = Mock()
|
assistant_resp = Mock()
|
||||||
assistant_resp.role = 'assistant'
|
assistant_resp.role = 'assistant'
|
||||||
assistant_resp.content = 'Hello!'
|
assistant_resp.content = "Hello!"
|
||||||
assistant_resp.tool_calls = None
|
assistant_resp.tool_calls = None
|
||||||
assistant_resp.get_content_platform_message_chain = Mock(
|
assistant_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [assistant_resp]
|
query.resp_messages = [assistant_resp]
|
||||||
|
|
||||||
@@ -382,7 +384,7 @@ class TestResponseWrapperCustomReply:
|
|||||||
app.sess_mgr.get_session = AsyncMock(return_value=session)
|
app.sess_mgr.get_session = AsyncMock(return_value=session)
|
||||||
|
|
||||||
# Mock plugin connector with custom reply
|
# Mock plugin connector with custom reply
|
||||||
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
|
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
|
||||||
mock_event_ctx = Mock()
|
mock_event_ctx = Mock()
|
||||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||||
mock_event_ctx.event = Mock()
|
mock_event_ctx.event = Mock()
|
||||||
@@ -395,17 +397,17 @@ class TestResponseWrapperCustomReply:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
|
|
||||||
# Create assistant response
|
# Create assistant response
|
||||||
assistant_resp = Mock()
|
assistant_resp = Mock()
|
||||||
assistant_resp.role = 'assistant'
|
assistant_resp.role = 'assistant'
|
||||||
assistant_resp.content = 'Default reply'
|
assistant_resp.content = "Default reply"
|
||||||
assistant_resp.tool_calls = None
|
assistant_resp.tool_calls = None
|
||||||
assistant_resp.get_content_platform_message_chain = Mock(
|
assistant_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [assistant_resp]
|
query.resp_messages = [assistant_resp]
|
||||||
|
|
||||||
@@ -419,7 +421,7 @@ class TestResponseWrapperCustomReply:
|
|||||||
assert len(results[0].new_query.resp_message_chain) == 1
|
assert len(results[0].new_query.resp_message_chain) == 1
|
||||||
# Should be the custom chain
|
# Should be the custom chain
|
||||||
chain = results[0].new_query.resp_message_chain[0]
|
chain = results[0].new_query.resp_message_chain[0]
|
||||||
assert 'Custom reply' in str(chain)
|
assert "Custom reply" in str(chain)
|
||||||
|
|
||||||
|
|
||||||
class TestResponseWrapperVariables:
|
class TestResponseWrapperVariables:
|
||||||
@@ -450,7 +452,7 @@ class TestResponseWrapperVariables:
|
|||||||
|
|
||||||
await stage.initialize(pipeline_config)
|
await stage.initialize(pipeline_config)
|
||||||
|
|
||||||
query = text_query('hello')
|
query = text_query("hello")
|
||||||
query.pipeline_config = pipeline_config
|
query.pipeline_config = pipeline_config
|
||||||
query.resp_message_chain = []
|
query.resp_message_chain = []
|
||||||
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
|
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
|
||||||
@@ -458,10 +460,10 @@ class TestResponseWrapperVariables:
|
|||||||
# Create assistant response
|
# Create assistant response
|
||||||
assistant_resp = Mock()
|
assistant_resp = Mock()
|
||||||
assistant_resp.role = 'assistant'
|
assistant_resp.role = 'assistant'
|
||||||
assistant_resp.content = 'Hello'
|
assistant_resp.content = "Hello"
|
||||||
assistant_resp.tool_calls = None
|
assistant_resp.tool_calls = None
|
||||||
assistant_resp.get_content_platform_message_chain = Mock(
|
assistant_resp.get_content_platform_message_chain = Mock(
|
||||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
|
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
|
||||||
)
|
)
|
||||||
query.resp_messages = [assistant_resp]
|
query.resp_messages = [assistant_resp]
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Tests cover:
|
|||||||
- RAG methods (ingest, retrieve, schema)
|
- RAG methods (ingest, retrieve, schema)
|
||||||
- Disabled plugin early returns
|
- Disabled plugin early returns
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -87,12 +86,16 @@ class TestListPlugins:
|
|||||||
return_value=[
|
return_value=[
|
||||||
{
|
{
|
||||||
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
|
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
|
||||||
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}],
|
'components': [
|
||||||
|
{'manifest': {'manifest': {'kind': 'Command'}}}
|
||||||
|
],
|
||||||
'debug': False,
|
'debug': False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
|
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
|
||||||
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}],
|
'components': [
|
||||||
|
{'manifest': {'manifest': {'kind': 'Tool'}}}
|
||||||
|
],
|
||||||
'debug': False,
|
'debug': False,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
@@ -124,7 +127,9 @@ class TestListPlugins:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([])))
|
connector.ap.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(__iter__=lambda self: iter([]))
|
||||||
|
)
|
||||||
|
|
||||||
result = await connector.list_plugins()
|
result = await connector.list_plugins()
|
||||||
|
|
||||||
@@ -225,8 +230,7 @@ class TestCallParser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
connector.handler.parse_document.assert_called_once_with(
|
connector.handler.parse_document.assert_called_once_with(
|
||||||
'author',
|
'author', 'parser',
|
||||||
'parser',
|
|
||||||
{'mime_type': 'text/plain', 'filename': 'test.txt'},
|
{'mime_type': 'text/plain', 'filename': 'test.txt'},
|
||||||
b'file content',
|
b'file content',
|
||||||
)
|
)
|
||||||
@@ -247,7 +251,9 @@ class TestRAGMethods:
|
|||||||
|
|
||||||
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
|
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
|
||||||
|
|
||||||
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'})
|
connector.handler.rag_ingest_document.assert_called_once_with(
|
||||||
|
'author', 'engine', {'file': 'test.pdf'}
|
||||||
|
)
|
||||||
assert result['status'] == 'success'
|
assert result['status'] == 'success'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -258,16 +264,14 @@ class TestRAGMethods:
|
|||||||
|
|
||||||
connector.handler = AsyncMock()
|
connector.handler = AsyncMock()
|
||||||
connector.handler.retrieve_knowledge = AsyncMock(
|
connector.handler.retrieve_knowledge = AsyncMock(
|
||||||
return_value={
|
return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
|
||||||
'results': [
|
|
||||||
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
|
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
|
||||||
|
|
||||||
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'})
|
connector.handler.retrieve_knowledge.assert_called_once_with(
|
||||||
|
'author', 'engine', '', {'query': 'test'}
|
||||||
|
)
|
||||||
assert result == {
|
assert result == {
|
||||||
'results': [
|
'results': [
|
||||||
{
|
{
|
||||||
@@ -286,7 +290,9 @@ class TestRAGMethods:
|
|||||||
connector = create_mock_connector()
|
connector = create_mock_connector()
|
||||||
|
|
||||||
connector.handler = AsyncMock()
|
connector.handler = AsyncMock()
|
||||||
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}})
|
connector.handler.get_rag_creation_schema = AsyncMock(
|
||||||
|
return_value={'properties': {'name': {'type': 'string'}}}
|
||||||
|
)
|
||||||
|
|
||||||
result = await connector.get_rag_creation_schema('author/engine')
|
result = await connector.get_rag_creation_schema('author/engine')
|
||||||
|
|
||||||
@@ -320,7 +326,9 @@ class TestRAGMethods:
|
|||||||
|
|
||||||
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
|
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
|
||||||
|
|
||||||
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'})
|
connector.handler.rag_on_kb_create.assert_called_once_with(
|
||||||
|
'author', 'engine', 'kb-uuid', {'model': 'test'}
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rag_on_kb_delete(self):
|
async def test_rag_on_kb_delete(self):
|
||||||
@@ -346,7 +354,9 @@ class TestRAGMethods:
|
|||||||
|
|
||||||
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
|
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
|
||||||
|
|
||||||
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid')
|
connector.handler.rag_delete_document.assert_called_once_with(
|
||||||
|
'author', 'engine', 'doc-uuid', 'kb-uuid'
|
||||||
|
)
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
@@ -436,7 +446,9 @@ class TestGetPluginInfo:
|
|||||||
connector = create_mock_connector()
|
connector = create_mock_connector()
|
||||||
|
|
||||||
connector.handler = AsyncMock()
|
connector.handler = AsyncMock()
|
||||||
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}})
|
connector.handler.get_plugin_info = AsyncMock(
|
||||||
|
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
|
||||||
|
)
|
||||||
|
|
||||||
result = await connector.get_plugin_info('author', 'plugin')
|
result = await connector.get_plugin_info('author', 'plugin')
|
||||||
|
|
||||||
@@ -458,7 +470,9 @@ class TestSetPluginConfig:
|
|||||||
|
|
||||||
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
|
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
|
||||||
|
|
||||||
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'})
|
connector.handler.set_plugin_config.assert_called_once_with(
|
||||||
|
'author', 'plugin', {'setting': 'value'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestPingPluginRuntime:
|
class TestPingPluginRuntime:
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
Tests cover:
|
Tests cover:
|
||||||
- _parse_plugin_id() parsing and validation
|
- _parse_plugin_id() parsing and validation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Tests cover:
|
|||||||
- Handling missing requirements.txt
|
- Handling missing requirements.txt
|
||||||
- Handling empty/malformed requirements.txt
|
- Handling empty/malformed requirements.txt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import zipfile
|
import zipfile
|
||||||
@@ -83,13 +82,13 @@ class TestExtractDepsMetadata:
|
|||||||
"""Test that comments and empty lines are filtered."""
|
"""Test that comments and empty lines are filtered."""
|
||||||
connector_instance = create_mock_connector()
|
connector_instance = create_mock_connector()
|
||||||
|
|
||||||
requirements = """# This is a comment
|
requirements = '''# This is a comment
|
||||||
requests>=2.0
|
requests>=2.0
|
||||||
|
|
||||||
# Another comment
|
# Another comment
|
||||||
flask==1.0
|
flask==1.0
|
||||||
|
|
||||||
numpy"""
|
numpy'''
|
||||||
zip_bytes = create_zip_with_requirements(requirements)
|
zip_bytes = create_zip_with_requirements(requirements)
|
||||||
|
|
||||||
task_context = Mock()
|
task_context = Mock()
|
||||||
@@ -148,9 +147,9 @@ numpy"""
|
|||||||
"""Test handling requirements.txt with only comments."""
|
"""Test handling requirements.txt with only comments."""
|
||||||
connector_instance = create_mock_connector()
|
connector_instance = create_mock_connector()
|
||||||
|
|
||||||
requirements = """# Comment 1
|
requirements = '''# Comment 1
|
||||||
# Comment 2
|
# Comment 2
|
||||||
# Comment 3"""
|
# Comment 3'''
|
||||||
zip_bytes = create_zip_with_requirements(requirements)
|
zip_bytes = create_zip_with_requirements(requirements)
|
||||||
|
|
||||||
task_context = Mock()
|
task_context = Mock()
|
||||||
|
|||||||
@@ -40,13 +40,11 @@ class TestHandlerQueryVariables:
|
|||||||
"""Test set_query_var returns error when query not found."""
|
"""Test set_query_var returns error when query not found."""
|
||||||
runtime_handler = make_handler(mock_app)
|
runtime_handler = make_handler(mock_app)
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
|
||||||
{
|
'query_id': 'nonexistent-query',
|
||||||
'query_id': 'nonexistent-query',
|
'key': 'test_var',
|
||||||
'key': 'test_var',
|
'value': 'test_value',
|
||||||
'value': 'test_value',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code != 0
|
assert response.code != 0
|
||||||
assert 'nonexistent-query' in response.message
|
assert 'nonexistent-query' in response.message
|
||||||
@@ -60,13 +58,11 @@ class TestHandlerQueryVariables:
|
|||||||
|
|
||||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
|
||||||
{
|
'query_id': 'test-query',
|
||||||
'query_id': 'test-query',
|
'key': 'test_var',
|
||||||
'key': 'test_var',
|
'value': 'test_value',
|
||||||
'value': 'test_value',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert mock_query.variables['test_var'] == 'test_value'
|
assert mock_query.variables['test_var'] == 'test_value'
|
||||||
@@ -80,12 +76,10 @@ class TestHandlerQueryVariables:
|
|||||||
|
|
||||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
|
||||||
{
|
'query_id': 'test-query',
|
||||||
'query_id': 'test-query',
|
'key': 'existing_var',
|
||||||
'key': 'existing_var',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {'value': 'existing_value'}
|
assert response.data == {'value': 'existing_value'}
|
||||||
@@ -99,11 +93,9 @@ class TestHandlerQueryVariables:
|
|||||||
|
|
||||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
|
||||||
{
|
'query_id': 'test-query',
|
||||||
'query_id': 'test-query',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {'vars': mock_query.variables}
|
assert response.data == {'vars': mock_query.variables}
|
||||||
@@ -116,7 +108,7 @@ class TestHandlerRagErrorResponse:
|
|||||||
"""Test basic error response creation."""
|
"""Test basic error response creation."""
|
||||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||||
|
|
||||||
error = Exception('test error')
|
error = Exception("test error")
|
||||||
response = _make_rag_error_response(error, 'TestError')
|
response = _make_rag_error_response(error, 'TestError')
|
||||||
|
|
||||||
# ActionResponse is a pydantic model, check message field
|
# ActionResponse is a pydantic model, check message field
|
||||||
@@ -128,8 +120,13 @@ class TestHandlerRagErrorResponse:
|
|||||||
"""Test error response with extra context."""
|
"""Test error response with extra context."""
|
||||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||||
|
|
||||||
error = ValueError('invalid input')
|
error = ValueError("invalid input")
|
||||||
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test')
|
response = _make_rag_error_response(
|
||||||
|
error,
|
||||||
|
'ValidationError',
|
||||||
|
field='name',
|
||||||
|
value='test'
|
||||||
|
)
|
||||||
|
|
||||||
assert 'ValidationError' in response.message
|
assert 'ValidationError' in response.message
|
||||||
assert 'field=name' in response.message
|
assert 'field=name' in response.message
|
||||||
@@ -140,7 +137,7 @@ class TestHandlerRagErrorResponse:
|
|||||||
"""Test error response includes exception type."""
|
"""Test error response includes exception type."""
|
||||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||||
|
|
||||||
error = RuntimeError('connection failed')
|
error = RuntimeError("connection failed")
|
||||||
response = _make_rag_error_response(error, 'ConnectionError')
|
response = _make_rag_error_response(error, 'ConnectionError')
|
||||||
|
|
||||||
assert 'RuntimeError' in response.message
|
assert 'RuntimeError' in response.message
|
||||||
@@ -151,7 +148,7 @@ class TestHandlerRagErrorResponse:
|
|||||||
"""Test error response with no extra context."""
|
"""Test error response with no extra context."""
|
||||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||||
|
|
||||||
error = KeyError('missing_key')
|
error = KeyError("missing_key")
|
||||||
response = _make_rag_error_response(error, 'LookupError')
|
response = _make_rag_error_response(error, 'LookupError')
|
||||||
|
|
||||||
# No context parts means no brackets
|
# No context parts means no brackets
|
||||||
|
|||||||
@@ -47,14 +47,12 @@ class TestInitializePluginSettings:
|
|||||||
Mock(),
|
Mock(),
|
||||||
]
|
]
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
|
||||||
{
|
'plugin_author': 'test-author',
|
||||||
'plugin_author': 'test-author',
|
'plugin_name': 'test-plugin',
|
||||||
'plugin_name': 'test-plugin',
|
'install_source': 'local',
|
||||||
'install_source': 'local',
|
'install_info': {'path': '/test'},
|
||||||
'install_info': {'path': '/test'},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert app.persistence_mgr.execute_async.await_count == 2
|
assert app.persistence_mgr.execute_async.await_count == 2
|
||||||
@@ -84,14 +82,12 @@ class TestInitializePluginSettings:
|
|||||||
Mock(),
|
Mock(),
|
||||||
]
|
]
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
|
||||||
{
|
'plugin_author': 'test-author',
|
||||||
'plugin_author': 'test-author',
|
'plugin_name': 'test-plugin',
|
||||||
'plugin_name': 'test-plugin',
|
'install_source': 'github',
|
||||||
'install_source': 'github',
|
'install_info': {'repo': 'author/name'},
|
||||||
'install_info': {'repo': 'author/name'},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert app.persistence_mgr.execute_async.await_count == 3
|
assert app.persistence_mgr.execute_async.await_count == 3
|
||||||
@@ -165,7 +161,9 @@ class TestSetBinaryStorage:
|
|||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
|
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new'))
|
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
|
||||||
|
self.payload(b'new')
|
||||||
|
)
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert app.persistence_mgr.execute_async.await_count == 2
|
assert app.persistence_mgr.execute_async.await_count == 2
|
||||||
@@ -205,7 +203,9 @@ class TestSetBinaryStorage:
|
|||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
|
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x'))
|
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
|
||||||
|
self.payload(b'x')
|
||||||
|
)
|
||||||
|
|
||||||
assert response.code != 0
|
assert response.code != 0
|
||||||
assert '1 > 0 bytes' in response.message
|
assert '1 > 0 bytes' in response.message
|
||||||
@@ -228,12 +228,10 @@ class TestGetPluginSettings:
|
|||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
app.persistence_mgr.execute_async.return_value = make_result()
|
app.persistence_mgr.execute_async.return_value = make_result()
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
|
||||||
{
|
'plugin_author': 'test-author',
|
||||||
'plugin_author': 'test-author',
|
'plugin_name': 'test-plugin',
|
||||||
'plugin_name': 'test-plugin',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {
|
assert response.data == {
|
||||||
@@ -257,12 +255,10 @@ class TestGetPluginSettings:
|
|||||||
)
|
)
|
||||||
app.persistence_mgr.execute_async.return_value = make_result(setting)
|
app.persistence_mgr.execute_async.return_value = make_result(setting)
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
|
||||||
{
|
'plugin_author': 'test-author',
|
||||||
'plugin_author': 'test-author',
|
'plugin_name': 'test-plugin',
|
||||||
'plugin_name': 'test-plugin',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {
|
assert response.data == {
|
||||||
@@ -290,13 +286,11 @@ class TestGetBinaryStorage:
|
|||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
|
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
|
||||||
{
|
'key': 'test-key',
|
||||||
'key': 'test-key',
|
'owner_type': 'plugin',
|
||||||
'owner_type': 'plugin',
|
'owner': 'test-owner',
|
||||||
'owner': 'test-owner',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {
|
assert response.data == {
|
||||||
@@ -309,13 +303,11 @@ class TestGetBinaryStorage:
|
|||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
app.persistence_mgr.execute_async.return_value = make_result()
|
app.persistence_mgr.execute_async.return_value = make_result()
|
||||||
|
|
||||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
|
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
|
||||||
{
|
'key': 'test-key',
|
||||||
'key': 'test-key',
|
'owner_type': 'plugin',
|
||||||
'owner_type': 'plugin',
|
'owner': 'test-owner',
|
||||||
'owner': 'test-owner',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code != 0
|
assert response.code != 0
|
||||||
assert 'Storage with key test-key not found' in response.message
|
assert 'Storage with key test-key not found' in response.message
|
||||||
@@ -337,11 +329,9 @@ class TestHandlerQueryLookup:
|
|||||||
"""Query-bound actions return error when query_id is not cached."""
|
"""Query-bound actions return error when query_id is not cached."""
|
||||||
runtime_handler = make_handler(app)
|
runtime_handler = make_handler(app)
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
|
||||||
{
|
'query_id': 'nonexistent-query',
|
||||||
'query_id': 'nonexistent-query',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code != 0
|
assert response.code != 0
|
||||||
assert 'nonexistent-query' in response.message
|
assert 'nonexistent-query' in response.message
|
||||||
@@ -353,11 +343,9 @@ class TestHandlerQueryLookup:
|
|||||||
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
|
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
|
||||||
app.query_pool.cached_queries['existing-query'] = query
|
app.query_pool.cached_queries['existing-query'] = query
|
||||||
|
|
||||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
|
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
|
||||||
{
|
'query_id': 'existing-query',
|
||||||
'query_id': 'existing-query',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.code == 0
|
assert response.code == 0
|
||||||
assert response.data == {'bot_uuid': 'test-bot-uuid'}
|
assert response.data == {'bot_uuid': 'test-bot-uuid'}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Tests cover:
|
|||||||
- _make_rag_error_response() helper function
|
- _make_rag_error_response() helper function
|
||||||
- RuntimeConnectionHandler cleanup_plugin_data method
|
- RuntimeConnectionHandler cleanup_plugin_data method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -24,7 +23,7 @@ class TestMakeRagErrorResponse:
|
|||||||
"""Test basic error response creation."""
|
"""Test basic error response creation."""
|
||||||
handler = get_handler_module()
|
handler = get_handler_module()
|
||||||
|
|
||||||
error = ValueError('test error message')
|
error = ValueError("test error message")
|
||||||
result = handler._make_rag_error_response(error, 'TestError')
|
result = handler._make_rag_error_response(error, 'TestError')
|
||||||
|
|
||||||
# ActionResponse.error() returns code=1 (error status)
|
# ActionResponse.error() returns code=1 (error status)
|
||||||
@@ -37,7 +36,7 @@ class TestMakeRagErrorResponse:
|
|||||||
"""Test that error type is included in message."""
|
"""Test that error type is included in message."""
|
||||||
handler = get_handler_module()
|
handler = get_handler_module()
|
||||||
|
|
||||||
error = RuntimeError('something went wrong')
|
error = RuntimeError("something went wrong")
|
||||||
result = handler._make_rag_error_response(error, 'VectorStoreError')
|
result = handler._make_rag_error_response(error, 'VectorStoreError')
|
||||||
|
|
||||||
assert '[VectorStoreError/RuntimeError]' in result.message
|
assert '[VectorStoreError/RuntimeError]' in result.message
|
||||||
@@ -46,7 +45,7 @@ class TestMakeRagErrorResponse:
|
|||||||
"""Test that extra context fields are included."""
|
"""Test that extra context fields are included."""
|
||||||
handler = get_handler_module()
|
handler = get_handler_module()
|
||||||
|
|
||||||
error = Exception('embedding failed')
|
error = Exception("embedding failed")
|
||||||
result = handler._make_rag_error_response(
|
result = handler._make_rag_error_response(
|
||||||
error,
|
error,
|
||||||
'EmbeddingError',
|
'EmbeddingError',
|
||||||
@@ -72,7 +71,7 @@ class TestMakeRagErrorResponse:
|
|||||||
"""Test multiple context fields are comma separated."""
|
"""Test multiple context fields are comma separated."""
|
||||||
handler = get_handler_module()
|
handler = get_handler_module()
|
||||||
|
|
||||||
error = IOError('file not found')
|
error = IOError("file not found")
|
||||||
result = handler._make_rag_error_response(
|
result = handler._make_rag_error_response(
|
||||||
error,
|
error,
|
||||||
'FileServiceError',
|
'FileServiceError',
|
||||||
@@ -120,7 +119,9 @@ class TestCleanupPluginData:
|
|||||||
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
|
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
|
||||||
handler_instance.ap = mock_app
|
handler_instance.ap = mock_app
|
||||||
|
|
||||||
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name')
|
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
|
||||||
|
handler_instance, 'author', 'plugin-name'
|
||||||
|
)
|
||||||
|
|
||||||
# Should have at least 2 calls: one for settings, one for binary storage
|
# Should have at least 2 calls: one for settings, one for binary storage
|
||||||
assert mock_app.persistence_mgr.execute_async.call_count >= 2
|
assert mock_app.persistence_mgr.execute_async.call_count >= 2
|
||||||
@@ -88,10 +88,7 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
|
|||||||
|
|
||||||
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
|
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
|
||||||
return provider_message.Message(
|
|
||||||
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
|
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
|
||||||
"""Return fake rerank results."""
|
"""Return fake rerank results."""
|
||||||
@@ -138,10 +135,8 @@ def mock_app_for_modelmgr():
|
|||||||
|
|
||||||
# Fake persistence manager - returns empty results by default
|
# Fake persistence manager - returns empty results by default
|
||||||
app.persistence_mgr = SimpleNamespace()
|
app.persistence_mgr = SimpleNamespace()
|
||||||
|
|
||||||
async def default_execute(query):
|
async def default_execute(query):
|
||||||
return _make_mock_result([])
|
return _make_mock_result([])
|
||||||
|
|
||||||
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
|
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
|
||||||
|
|
||||||
# Fake discover engine
|
# Fake discover engine
|
||||||
@@ -170,7 +165,9 @@ def fake_requester_registry(mock_app_for_modelmgr):
|
|||||||
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
|
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
|
||||||
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
|
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
|
||||||
|
|
||||||
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component])
|
app.discover.get_components_by_kind = Mock(
|
||||||
|
return_value=[fake_component, another_component]
|
||||||
|
)
|
||||||
|
|
||||||
model_mgr = ModelManager(app)
|
model_mgr = ModelManager(app)
|
||||||
return model_mgr
|
return model_mgr
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
|
|||||||
'base-url': 'https://api.dify.ai',
|
'base-url': 'https://api.dify.ai',
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'output': {'misc': {}},
|
'output': {'misc': {}}
|
||||||
}
|
}
|
||||||
|
|
||||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||||
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
|
|||||||
'base-url': 'https://api.dify.ai',
|
'base-url': 'https://api.dify.ai',
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'output': {'misc': {}},
|
'output': {'misc': {}}
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(DifyAPIError, match='不支持'):
|
with pytest.raises(DifyAPIError, match='不支持'):
|
||||||
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
|
|||||||
'base-url': 'https://api.dify.ai',
|
'base-url': 'https://api.dify.ai',
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'output': {'misc': {}},
|
'output': {'misc': {}}
|
||||||
}
|
}
|
||||||
|
|
||||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||||
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
|
|||||||
'base-url': 'https://api.dify.ai',
|
'base-url': 'https://api.dify.ai',
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'output': {'misc': {}},
|
'output': {'misc': {}}
|
||||||
}
|
}
|
||||||
|
|
||||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||||
|
|
||||||
assert runner.pipeline_config == pipeline_config
|
assert runner.pipeline_config == pipeline_config
|
||||||
assert runner.ap == mock_app
|
assert runner.ap == mock_app
|
||||||
@@ -1062,7 +1062,9 @@ class TestScanModels:
|
|||||||
|
|
||||||
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
|
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
|
||||||
mock_get_model_info.side_effect = (
|
mock_get_model_info.side_effect = (
|
||||||
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {}
|
lambda model: {'max_input_tokens': 131072}
|
||||||
|
if model == 'moonshot/moonshot-v1-128k'
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert requester._safe_context_length('moonshot-v1-128k') == 131072
|
assert requester._safe_context_length('moonshot-v1-128k') == 131072
|
||||||
|
|||||||
@@ -635,9 +635,7 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_manager_load_llm_model_with_provider(
|
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
|
||||||
):
|
|
||||||
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
|
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
|
||||||
model_mgr = fake_requester_registry
|
model_mgr = fake_requester_registry
|
||||||
|
|
||||||
@@ -650,9 +648,7 @@ async def test_model_manager_load_llm_model_with_provider(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_manager_load_llm_model_with_provider_from_row(
|
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
|
||||||
):
|
|
||||||
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
|
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
|
||||||
model_mgr = fake_requester_registry
|
model_mgr = fake_requester_registry
|
||||||
|
|
||||||
@@ -665,9 +661,7 @@ async def test_model_manager_load_llm_model_with_provider_from_row(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_manager_load_embedding_model_with_provider(
|
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
|
||||||
):
|
|
||||||
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
|
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
|
||||||
model_mgr = fake_requester_registry
|
model_mgr = fake_requester_registry
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ class TestableRequester(requester.ProviderAPIRequester):
|
|||||||
remove_think=False,
|
remove_think=False,
|
||||||
):
|
):
|
||||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
return provider_message.Message(
|
return provider_message.Message(
|
||||||
role='assistant',
|
role='assistant',
|
||||||
content=[provider_message.ContentElement(type='text', text='Testable response')],
|
content=[provider_message.ContentElement(type='text', text='Testable response')],
|
||||||
@@ -290,9 +289,7 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
|
|||||||
current_stage_name=None,
|
current_stage_name=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await provider.invoke_llm(query, runtime_llm_model, messages)
|
result = await provider.invoke_llm(query, runtime_llm_model, messages)
|
||||||
|
|
||||||
@@ -333,9 +330,7 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
|
|||||||
current_stage_name=None,
|
current_stage_name=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
|
||||||
]
|
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
|
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
|
||||||
@@ -581,9 +576,7 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
|
|||||||
current_stage_name=None,
|
current_stage_name=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
|
||||||
]
|
|
||||||
|
|
||||||
with pytest.raises(RequesterError):
|
with pytest.raises(RequesterError):
|
||||||
await provider.invoke_llm(query, model, messages)
|
await provider.invoke_llm(query, model, messages)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ Tests cover:
|
|||||||
- Conversation creation with prompts
|
- Conversation creation with prompts
|
||||||
- Session concurrency semaphore
|
- Session concurrency semaphore
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -61,7 +60,11 @@ class TestSessionManagerGetSession:
|
|||||||
"""Create mock app with instance config."""
|
"""Create mock app with instance config."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
mock_app.instance_config = Mock()
|
mock_app.instance_config = Mock()
|
||||||
mock_app.instance_config.data = {'concurrency': {'session': 5}}
|
mock_app.instance_config.data = {
|
||||||
|
'concurrency': {
|
||||||
|
'session': 5
|
||||||
|
}
|
||||||
|
}
|
||||||
return mock_app
|
return mock_app
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -170,7 +173,11 @@ class TestSessionManagerGetConversation:
|
|||||||
"""Create mock app with instance config."""
|
"""Create mock app with instance config."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
mock_app.instance_config = Mock()
|
mock_app.instance_config = Mock()
|
||||||
mock_app.instance_config.data = {'concurrency': {'session': 5}}
|
mock_app.instance_config.data = {
|
||||||
|
'concurrency': {
|
||||||
|
'session': 5
|
||||||
|
}
|
||||||
|
}
|
||||||
return mock_app
|
return mock_app
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -194,13 +201,17 @@ class TestSessionManagerGetConversation:
|
|||||||
return query
|
return query
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session):
|
async def test_creates_conversation_with_prompt(
|
||||||
|
self, mock_app_with_config, sample_query, sample_session
|
||||||
|
):
|
||||||
"""Test that get_conversation creates conversation with prompt."""
|
"""Test that get_conversation creates conversation with prompt."""
|
||||||
sessionmgr = get_session_module()
|
sessionmgr = get_session_module()
|
||||||
|
|
||||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||||
|
|
||||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
prompt_config = [
|
||||||
|
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||||
|
]
|
||||||
pipeline_uuid = 'pipeline-123'
|
pipeline_uuid = 'pipeline-123'
|
||||||
bot_uuid = 'bot-123'
|
bot_uuid = 'bot-123'
|
||||||
|
|
||||||
@@ -223,15 +234,21 @@ class TestSessionManagerGetConversation:
|
|||||||
|
|
||||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||||
|
|
||||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
prompt_config = [
|
||||||
|
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||||
|
]
|
||||||
pipeline_uuid = 'pipeline-123'
|
pipeline_uuid = 'pipeline-123'
|
||||||
bot_uuid = 'bot-123'
|
bot_uuid = 'bot-123'
|
||||||
|
|
||||||
# First call creates conversation
|
# First call creates conversation
|
||||||
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
|
conv1 = await manager.get_conversation(
|
||||||
|
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
|
||||||
|
)
|
||||||
|
|
||||||
# Second call with same pipeline should return same conversation
|
# Second call with same pipeline should return same conversation
|
||||||
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
|
conv2 = await manager.get_conversation(
|
||||||
|
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
|
||||||
|
)
|
||||||
|
|
||||||
assert conv1 is conv2
|
assert conv1 is conv2
|
||||||
assert len(sample_session.conversations) == 1
|
assert len(sample_session.conversations) == 1
|
||||||
@@ -245,26 +262,36 @@ class TestSessionManagerGetConversation:
|
|||||||
|
|
||||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||||
|
|
||||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
prompt_config = [
|
||||||
|
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||||
|
]
|
||||||
|
|
||||||
# First call with pipeline1
|
# First call with pipeline1
|
||||||
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1')
|
conv1 = await manager.get_conversation(
|
||||||
|
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
|
||||||
|
)
|
||||||
|
|
||||||
# Second call with different pipeline should create new conversation
|
# Second call with different pipeline should create new conversation
|
||||||
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2')
|
conv2 = await manager.get_conversation(
|
||||||
|
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
|
||||||
|
)
|
||||||
|
|
||||||
assert conv1 is not conv2
|
assert conv1 is not conv2
|
||||||
assert len(sample_session.conversations) == 2
|
assert len(sample_session.conversations) == 2
|
||||||
assert sample_session.using_conversation is conv2
|
assert sample_session.using_conversation is conv2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session):
|
async def test_conversation_has_empty_messages(
|
||||||
|
self, mock_app_with_config, sample_query, sample_session
|
||||||
|
):
|
||||||
"""Test that created conversation has empty messages list."""
|
"""Test that created conversation has empty messages list."""
|
||||||
sessionmgr = get_session_module()
|
sessionmgr = get_session_module()
|
||||||
|
|
||||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||||
|
|
||||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
prompt_config = [
|
||||||
|
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||||
|
]
|
||||||
|
|
||||||
conversation = await manager.get_conversation(
|
conversation = await manager.get_conversation(
|
||||||
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
||||||
@@ -273,17 +300,22 @@ class TestSessionManagerGetConversation:
|
|||||||
assert conversation.messages == []
|
assert conversation.messages == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session):
|
async def test_prompt_messages_from_config(
|
||||||
|
self, mock_app_with_config, sample_query, sample_session
|
||||||
|
):
|
||||||
"""Test that prompt messages are created from prompt_config."""
|
"""Test that prompt messages are created from prompt_config."""
|
||||||
sessionmgr = get_session_module()
|
sessionmgr = get_session_module()
|
||||||
|
|
||||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||||
|
|
||||||
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}]
|
prompt_config = [
|
||||||
|
{'role': 'system', 'content': 'System message'},
|
||||||
|
{'role': 'user', 'content': 'User message'}
|
||||||
|
]
|
||||||
|
|
||||||
conversation = await manager.get_conversation(
|
conversation = await manager.get_conversation(
|
||||||
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
||||||
)
|
)
|
||||||
|
|
||||||
assert conversation.prompt.name == 'default'
|
assert conversation.prompt.name == 'default'
|
||||||
assert len(conversation.prompt.messages) == 2
|
assert len(conversation.prompt.messages) == 2
|
||||||
@@ -136,7 +136,6 @@ class TestToolManagerSchemaGeneration:
|
|||||||
assert 'description' in func
|
assert 'description' in func
|
||||||
assert 'parameters' in func
|
assert 'parameters' in func
|
||||||
|
|
||||||
|
|
||||||
class TestToolManagerExecuteFuncCall:
|
class TestToolManagerExecuteFuncCall:
|
||||||
"""Tests for execute_func_call method."""
|
"""Tests for execute_func_call method."""
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
Tests cover:
|
Tests cover:
|
||||||
- _to_i18n_name() static method
|
- _to_i18n_name() static method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@@ -61,4 +60,4 @@ class TestToI18nName:
|
|||||||
kbmgr = get_kbmgr_module()
|
kbmgr = get_kbmgr_module()
|
||||||
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
|
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
|
||||||
result = kbmgr.RAGManager._to_i18n_name(input_dict)
|
result = kbmgr.RAGManager._to_i18n_name(input_dict)
|
||||||
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
|
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
|
||||||
@@ -6,7 +6,6 @@ Tests cover:
|
|||||||
- Knowledge engine enrichment
|
- Knowledge engine enrichment
|
||||||
- KB loading and removal
|
- KB loading and removal
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -102,9 +101,13 @@ class TestRAGManagerCreateKnowledgeBase:
|
|||||||
rag_module = get_rag_module()
|
rag_module = get_rag_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
|
|
||||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
|
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||||
|
return_value=[{'plugin_id': 'author/engine'}]
|
||||||
|
)
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock()
|
mock_app.persistence_mgr.execute_async = AsyncMock()
|
||||||
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin error'))
|
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
|
||||||
|
side_effect=Exception('Plugin error')
|
||||||
|
)
|
||||||
|
|
||||||
manager = rag_module.RAGManager(mock_app)
|
manager = rag_module.RAGManager(mock_app)
|
||||||
|
|
||||||
@@ -125,7 +128,9 @@ class TestRAGManagerCreateKnowledgeBase:
|
|||||||
rag_module = get_rag_module()
|
rag_module = get_rag_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
|
|
||||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
|
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||||
|
return_value=[{'plugin_id': 'author/engine'}]
|
||||||
|
)
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock()
|
mock_app.persistence_mgr.execute_async = AsyncMock()
|
||||||
mock_app.plugin_connector.rag_on_kb_create = AsyncMock()
|
mock_app.plugin_connector.rag_on_kb_create = AsyncMock()
|
||||||
|
|
||||||
@@ -201,7 +206,9 @@ class TestRuntimeKnowledgeBaseOnKBCreate:
|
|||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_kb = create_mock_kb_entity()
|
mock_kb = create_mock_kb_entity()
|
||||||
|
|
||||||
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin failed'))
|
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
|
||||||
|
side_effect=Exception('Plugin failed')
|
||||||
|
)
|
||||||
|
|
||||||
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
||||||
|
|
||||||
@@ -238,7 +245,9 @@ class TestRuntimeKnowledgeBaseIngestDocument:
|
|||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_kb = create_mock_kb_entity()
|
mock_kb = create_mock_kb_entity()
|
||||||
|
|
||||||
mock_app.plugin_connector.call_rag_ingest = AsyncMock(return_value={'status': 'success'})
|
mock_app.plugin_connector.call_rag_ingest = AsyncMock(
|
||||||
|
return_value={'status': 'success'}
|
||||||
|
)
|
||||||
|
|
||||||
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
||||||
|
|
||||||
@@ -295,10 +304,14 @@ class TestRAGManagerLoadKnowledgeBasesFromDB:
|
|||||||
# KB that will cause initialize to fail
|
# KB that will cause initialize to fail
|
||||||
mock_kb = create_mock_kb_entity()
|
mock_kb = create_mock_kb_entity()
|
||||||
|
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb])))
|
mock_app.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(all=Mock(return_value=[mock_kb]))
|
||||||
|
)
|
||||||
|
|
||||||
# Make initialize fail by having plugin_connector throw error
|
# Make initialize fail by having plugin_connector throw error
|
||||||
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Init failed'))
|
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
|
||||||
|
side_effect=Exception('Init failed')
|
||||||
|
)
|
||||||
|
|
||||||
manager = rag_module.RAGManager(mock_app)
|
manager = rag_module.RAGManager(mock_app)
|
||||||
# Should not raise - errors are caught
|
# Should not raise - errors are caught
|
||||||
@@ -398,7 +411,9 @@ class TestRuntimeKnowledgeBaseRetrieve:
|
|||||||
mock_kb = create_mock_kb_entity()
|
mock_kb = create_mock_kb_entity()
|
||||||
mock_kb.retrieval_settings = {}
|
mock_kb.retrieval_settings = {}
|
||||||
|
|
||||||
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(return_value={'results': []})
|
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(
|
||||||
|
return_value={'results': []}
|
||||||
|
)
|
||||||
|
|
||||||
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
|
||||||
|
|
||||||
@@ -667,7 +682,9 @@ class TestRAGManagerGetAllDetails:
|
|||||||
"""Test returns empty list when no knowledge bases."""
|
"""Test returns empty list when no knowledge bases."""
|
||||||
rag_module = get_rag_module()
|
rag_module = get_rag_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
mock_app.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(all=Mock(return_value=[]))
|
||||||
|
)
|
||||||
|
|
||||||
manager = rag_module.RAGManager(mock_app)
|
manager = rag_module.RAGManager(mock_app)
|
||||||
result = await manager.get_all_knowledge_base_details()
|
result = await manager.get_all_knowledge_base_details()
|
||||||
@@ -682,7 +699,9 @@ class TestRAGManagerGetAllDetails:
|
|||||||
|
|
||||||
# Mock DB result
|
# Mock DB result
|
||||||
mock_kb_row = Mock()
|
mock_kb_row = Mock()
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb_row])))
|
mock_app.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(all=Mock(return_value=[mock_kb_row]))
|
||||||
|
)
|
||||||
mock_app.persistence_mgr.serialize_model = Mock(
|
mock_app.persistence_mgr.serialize_model = Mock(
|
||||||
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
|
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
|
||||||
)
|
)
|
||||||
@@ -705,7 +724,9 @@ class TestRAGManagerGetDetails:
|
|||||||
"""Test returns None when KB doesn't exist."""
|
"""Test returns None when KB doesn't exist."""
|
||||||
rag_module = get_rag_module()
|
rag_module = get_rag_module()
|
||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=None)))
|
mock_app.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(first=Mock(return_value=None))
|
||||||
|
)
|
||||||
|
|
||||||
manager = rag_module.RAGManager(mock_app)
|
manager = rag_module.RAGManager(mock_app)
|
||||||
result = await manager.get_knowledge_base_details('nonexistent')
|
result = await manager.get_knowledge_base_details('nonexistent')
|
||||||
@@ -719,7 +740,9 @@ class TestRAGManagerGetDetails:
|
|||||||
mock_app = create_mock_app()
|
mock_app = create_mock_app()
|
||||||
|
|
||||||
mock_kb_row = Mock()
|
mock_kb_row = Mock()
|
||||||
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=mock_kb_row)))
|
mock_app.persistence_mgr.execute_async = AsyncMock(
|
||||||
|
return_value=Mock(first=Mock(return_value=mock_kb_row))
|
||||||
|
)
|
||||||
mock_app.persistence_mgr.serialize_model = Mock(
|
mock_app.persistence_mgr.serialize_model = Mock(
|
||||||
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
|
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
|
||||||
)
|
)
|
||||||
@@ -768,4 +791,4 @@ class TestRAGManagerLoadKnowledgeBase:
|
|||||||
|
|
||||||
await manager.load_knowledge_base(kb_dict)
|
await manager.load_knowledge_base(kb_dict)
|
||||||
|
|
||||||
assert 'kb-uuid' in manager.knowledge_bases
|
assert 'kb-uuid' in manager.knowledge_bases
|
||||||
@@ -121,12 +121,10 @@ class TestRAGRuntimeServiceVectorSearch:
|
|||||||
"""Create mock app."""
|
"""Create mock app."""
|
||||||
mock_app = MagicMock()
|
mock_app = MagicMock()
|
||||||
mock_app.vector_db_mgr = MagicMock()
|
mock_app.vector_db_mgr = MagicMock()
|
||||||
mock_app.vector_db_mgr.search = AsyncMock(
|
mock_app.vector_db_mgr.search = AsyncMock(return_value=[
|
||||||
return_value=[
|
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
|
||||||
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
|
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
|
||||||
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
return mock_app
|
return mock_app
|
||||||
|
|
||||||
def _make_rag_import_mocks(self):
|
def _make_rag_import_mocks(self):
|
||||||
@@ -303,7 +301,10 @@ class TestRAGRuntimeServiceVectorList:
|
|||||||
mock_app = MagicMock()
|
mock_app = MagicMock()
|
||||||
mock_app.vector_db_mgr = MagicMock()
|
mock_app.vector_db_mgr = MagicMock()
|
||||||
mock_app.vector_db_mgr.list_by_filter = AsyncMock(
|
mock_app.vector_db_mgr.list_by_filter = AsyncMock(
|
||||||
return_value=([{'id': 'id1', 'metadata': {'file_id': 'abc'}}], 10)
|
return_value=(
|
||||||
|
[{'id': 'id1', 'metadata': {'file_id': 'abc'}}],
|
||||||
|
10
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return mock_app
|
return mock_app
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ from langbot.pkg.storage.providers.localstorage import LocalStorageProvider
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def storage_provider(tmp_path):
|
def storage_provider(tmp_path):
|
||||||
"""Create a LocalStorageProvider with a temporary storage path."""
|
"""Create a LocalStorageProvider with a temporary storage path."""
|
||||||
storage_path = str(tmp_path / 'storage')
|
storage_path = str(tmp_path / "storage")
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
provider = LocalStorageProvider(mock_app)
|
provider = LocalStorageProvider(mock_app)
|
||||||
yield provider, storage_path
|
yield provider, storage_path
|
||||||
@@ -35,15 +35,15 @@ class TestPathTraversalPrevention:
|
|||||||
async def test_absolute_path_save_rejected(self, storage_provider, tmp_path):
|
async def test_absolute_path_save_rejected(self, storage_provider, tmp_path):
|
||||||
"""Saving with an absolute path key must be blocked."""
|
"""Saving with an absolute path key must be blocked."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
target_file = str(tmp_path / 'pwned.txt')
|
target_file = str(tmp_path / "pwned.txt")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
with pytest.raises((ValueError, PermissionError)):
|
with pytest.raises((ValueError, PermissionError)):
|
||||||
await provider.save(target_file, b'malicious content')
|
await provider.save(target_file, b"malicious content")
|
||||||
|
|
||||||
# The file must NOT exist outside the storage directory
|
# The file must NOT exist outside the storage directory
|
||||||
assert not os.path.exists(target_file), (
|
assert not os.path.exists(target_file), (
|
||||||
f'Path traversal succeeded: file was written outside storage to {target_file}'
|
f"Path traversal succeeded: file was written outside storage to {target_file}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -52,28 +52,32 @@ class TestPathTraversalPrevention:
|
|||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
# Create a file outside the storage directory
|
# Create a file outside the storage directory
|
||||||
target_file = str(tmp_path / 'secret.txt')
|
target_file = str(tmp_path / "secret.txt")
|
||||||
with open(target_file, 'wb') as f:
|
with open(target_file, "wb") as f:
|
||||||
f.write(b'secret data')
|
f.write(b"secret data")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
||||||
data = await provider.load(target_file)
|
data = await provider.load(target_file)
|
||||||
assert data != b'secret data', 'Path traversal succeeded: read file outside storage'
|
assert data != b"secret data", (
|
||||||
|
"Path traversal succeeded: read file outside storage"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path):
|
async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path):
|
||||||
"""Exists check with an absolute path key must be blocked or return False."""
|
"""Exists check with an absolute path key must be blocked or return False."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
target_file = str(tmp_path / 'check_me.txt')
|
target_file = str(tmp_path / "check_me.txt")
|
||||||
with open(target_file, 'wb') as f:
|
with open(target_file, "wb") as f:
|
||||||
f.write(b'data')
|
f.write(b"data")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
try:
|
try:
|
||||||
result = await provider.exists(target_file)
|
result = await provider.exists(target_file)
|
||||||
assert result is False, 'Path traversal succeeded: exists() returned True for file outside storage'
|
assert result is False, (
|
||||||
|
"Path traversal succeeded: exists() returned True for file outside storage"
|
||||||
|
)
|
||||||
except (ValueError, PermissionError):
|
except (ValueError, PermissionError):
|
||||||
pass # Expected
|
pass # Expected
|
||||||
|
|
||||||
@@ -82,26 +86,28 @@ class TestPathTraversalPrevention:
|
|||||||
"""Deleting with an absolute path key must be blocked."""
|
"""Deleting with an absolute path key must be blocked."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
target_file = str(tmp_path / 'do_not_delete.txt')
|
target_file = str(tmp_path / "do_not_delete.txt")
|
||||||
with open(target_file, 'wb') as f:
|
with open(target_file, "wb") as f:
|
||||||
f.write(b'important data')
|
f.write(b"important data")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
||||||
await provider.delete(target_file)
|
await provider.delete(target_file)
|
||||||
|
|
||||||
assert os.path.exists(target_file), 'Path traversal succeeded: file outside storage was deleted'
|
assert os.path.exists(target_file), (
|
||||||
|
"Path traversal succeeded: file outside storage was deleted"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_absolute_path_size_rejected(self, storage_provider, tmp_path):
|
async def test_absolute_path_size_rejected(self, storage_provider, tmp_path):
|
||||||
"""Size check with an absolute path key must be blocked."""
|
"""Size check with an absolute path key must be blocked."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
target_file = str(tmp_path / 'measure_me.txt')
|
target_file = str(tmp_path / "measure_me.txt")
|
||||||
with open(target_file, 'wb') as f:
|
with open(target_file, "wb") as f:
|
||||||
f.write(b'some data')
|
f.write(b"some data")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
||||||
await provider.size(target_file)
|
await provider.size(target_file)
|
||||||
|
|
||||||
@@ -110,39 +116,41 @@ class TestPathTraversalPrevention:
|
|||||||
"""Relative path traversal with '..' must be blocked."""
|
"""Relative path traversal with '..' must be blocked."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
target_file = str(tmp_path / 'above_storage.txt')
|
target_file = str(tmp_path / "above_storage.txt")
|
||||||
with open(target_file, 'wb') as f:
|
with open(target_file, "wb") as f:
|
||||||
f.write(b'above storage secret')
|
f.write(b"above storage secret")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
relative_key = os.path.join('..', 'above_storage.txt')
|
relative_key = os.path.join("..", "above_storage.txt")
|
||||||
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
|
||||||
data = await provider.load(relative_key)
|
data = await provider.load(relative_key)
|
||||||
assert data != b'above storage secret'
|
assert data != b"above storage secret"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path):
|
async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path):
|
||||||
"""delete_dir_recursive with traversal path must be blocked."""
|
"""delete_dir_recursive with traversal path must be blocked."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
outside_dir = tmp_path / 'outside_dir'
|
outside_dir = tmp_path / "outside_dir"
|
||||||
outside_dir.mkdir()
|
outside_dir.mkdir()
|
||||||
(outside_dir / 'file.txt').write_text('important')
|
(outside_dir / "file.txt").write_text("important")
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
with pytest.raises((ValueError, PermissionError)):
|
with pytest.raises((ValueError, PermissionError)):
|
||||||
await provider.delete_dir_recursive(str(outside_dir))
|
await provider.delete_dir_recursive(str(outside_dir))
|
||||||
|
|
||||||
assert outside_dir.exists(), 'Path traversal succeeded: directory outside storage was deleted'
|
assert outside_dir.exists(), (
|
||||||
|
"Path traversal succeeded: directory outside storage was deleted"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_legitimate_key_works(self, storage_provider):
|
async def test_legitimate_key_works(self, storage_provider):
|
||||||
"""Normal keys without traversal must still work."""
|
"""Normal keys without traversal must still work."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
key = 'test_image_abc123.png'
|
key = "test_image_abc123.png"
|
||||||
content = b'PNG image data'
|
content = b"PNG image data"
|
||||||
|
|
||||||
await provider.save(key, content)
|
await provider.save(key, content)
|
||||||
assert await provider.exists(key) is True
|
assert await provider.exists(key) is True
|
||||||
@@ -158,9 +166,9 @@ class TestPathTraversalPrevention:
|
|||||||
"""Keys with legitimate subdirectories must still work."""
|
"""Keys with legitimate subdirectories must still work."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
key = 'bot_log_images/img_001.png'
|
key = "bot_log_images/img_001.png"
|
||||||
content = b'PNG image data'
|
content = b"PNG image data"
|
||||||
|
|
||||||
await provider.save(key, content)
|
await provider.save(key, content)
|
||||||
assert await provider.exists(key) is True
|
assert await provider.exists(key) is True
|
||||||
@@ -173,33 +181,33 @@ class TestPathTraversalPrevention:
|
|||||||
"""delete_dir_recursive should handle non-existing directories gracefully."""
|
"""delete_dir_recursive should handle non-existing directories gracefully."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
# Try to delete a non-existing directory - should not raise
|
# Try to delete a non-existing directory - should not raise
|
||||||
await provider.delete_dir_recursive('nonexistent_dir')
|
await provider.delete_dir_recursive("nonexistent_dir")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_dir_recursive_with_files(self, storage_provider):
|
async def test_delete_dir_recursive_with_files(self, storage_provider):
|
||||||
"""delete_dir_recursive should delete directory with files inside."""
|
"""delete_dir_recursive should delete directory with files inside."""
|
||||||
provider, storage_path = storage_provider
|
provider, storage_path = storage_provider
|
||||||
|
|
||||||
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
|
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
|
||||||
# Create a directory with files
|
# Create a directory with files
|
||||||
key1 = 'test_dir/file1.txt'
|
key1 = "test_dir/file1.txt"
|
||||||
key2 = 'test_dir/file2.txt'
|
key2 = "test_dir/file2.txt"
|
||||||
await provider.save(key1, b'content1')
|
await provider.save(key1, b"content1")
|
||||||
await provider.save(key2, b'content2')
|
await provider.save(key2, b"content2")
|
||||||
|
|
||||||
# Verify files exist
|
# Verify files exist
|
||||||
assert await provider.exists(key1)
|
assert await provider.exists(key1)
|
||||||
assert await provider.exists(key2)
|
assert await provider.exists(key2)
|
||||||
|
|
||||||
# Delete directory recursively
|
# Delete directory recursively
|
||||||
await provider.delete_dir_recursive('test_dir')
|
await provider.delete_dir_recursive("test_dir")
|
||||||
|
|
||||||
# Verify files no longer exist
|
# Verify files no longer exist
|
||||||
assert not await provider.exists(key1)
|
assert not await provider.exists(key1)
|
||||||
assert not await provider.exists(key2)
|
assert not await provider.exists(key2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ Tests cover:
|
|||||||
|
|
||||||
Uses moto library to mock AWS S3 service.
|
Uses moto library to mock AWS S3 service.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -45,10 +44,8 @@ def mock_app_with_s3_config():
|
|||||||
def s3_mock():
|
def s3_mock():
|
||||||
"""Set up moto S3 mock context."""
|
"""Set up moto S3 mock context."""
|
||||||
from moto import mock_aws
|
from moto import mock_aws
|
||||||
|
|
||||||
with mock_aws():
|
with mock_aws():
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
# Create bucket for tests that need pre-existing bucket
|
# Create bucket for tests that need pre-existing bucket
|
||||||
s3 = boto3.client('s3', region_name='us-east-1')
|
s3 = boto3.client('s3', region_name='us-east-1')
|
||||||
yield s3
|
yield s3
|
||||||
@@ -328,4 +325,4 @@ class TestS3StorageProviderErrorHandling:
|
|||||||
await provider.initialize()
|
await provider.initialize()
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await provider.size('nonexistent.txt')
|
await provider.size('nonexistent.txt')
|
||||||
@@ -31,7 +31,7 @@ class TestStorageMgr:
|
|||||||
|
|
||||||
storage_mgr = StorageMgr(mock_app)
|
storage_mgr = StorageMgr(mock_app)
|
||||||
|
|
||||||
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
|
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
|
||||||
await storage_mgr.initialize()
|
await storage_mgr.initialize()
|
||||||
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
||||||
mock_app.logger.info.assert_called()
|
mock_app.logger.info.assert_called()
|
||||||
@@ -41,12 +41,12 @@ class TestStorageMgr:
|
|||||||
"""Should use local storage when explicitly configured."""
|
"""Should use local storage when explicitly configured."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
mock_app.instance_config = Mock()
|
mock_app.instance_config = Mock()
|
||||||
mock_app.instance_config.data = {'storage': {'use': 'local'}}
|
mock_app.instance_config.data = {"storage": {"use": "local"}}
|
||||||
mock_app.logger = Mock()
|
mock_app.logger = Mock()
|
||||||
|
|
||||||
storage_mgr = StorageMgr(mock_app)
|
storage_mgr = StorageMgr(mock_app)
|
||||||
|
|
||||||
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
|
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
|
||||||
await storage_mgr.initialize()
|
await storage_mgr.initialize()
|
||||||
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
||||||
|
|
||||||
@@ -55,12 +55,14 @@ class TestStorageMgr:
|
|||||||
"""Should use S3 storage when configured."""
|
"""Should use S3 storage when configured."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
mock_app.instance_config = Mock()
|
mock_app.instance_config = Mock()
|
||||||
mock_app.instance_config.data = {'storage': {'use': 's3', 's3': {'endpoint_url': 'https://s3.amazonaws.com'}}}
|
mock_app.instance_config.data = {
|
||||||
|
"storage": {"use": "s3", "s3": {"endpoint_url": "https://s3.amazonaws.com"}}
|
||||||
|
}
|
||||||
mock_app.logger = Mock()
|
mock_app.logger = Mock()
|
||||||
|
|
||||||
storage_mgr = StorageMgr(mock_app)
|
storage_mgr = StorageMgr(mock_app)
|
||||||
|
|
||||||
with patch.object(S3StorageProvider, 'initialize', new_callable=AsyncMock):
|
with patch.object(S3StorageProvider, "initialize", new_callable=AsyncMock):
|
||||||
await storage_mgr.initialize()
|
await storage_mgr.initialize()
|
||||||
assert isinstance(storage_mgr.storage_provider, S3StorageProvider)
|
assert isinstance(storage_mgr.storage_provider, S3StorageProvider)
|
||||||
|
|
||||||
@@ -69,12 +71,12 @@ class TestStorageMgr:
|
|||||||
"""Should default to local storage for invalid storage type."""
|
"""Should default to local storage for invalid storage type."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
mock_app.instance_config = Mock()
|
mock_app.instance_config = Mock()
|
||||||
mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}}
|
mock_app.instance_config.data = {"storage": {"use": "invalid_type"}}
|
||||||
mock_app.logger = Mock()
|
mock_app.logger = Mock()
|
||||||
|
|
||||||
storage_mgr = StorageMgr(mock_app)
|
storage_mgr = StorageMgr(mock_app)
|
||||||
|
|
||||||
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
|
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
|
||||||
await storage_mgr.initialize()
|
await storage_mgr.initialize()
|
||||||
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
|
||||||
|
|
||||||
@@ -88,7 +90,9 @@ class TestStorageMgr:
|
|||||||
|
|
||||||
storage_mgr = StorageMgr(mock_app)
|
storage_mgr = StorageMgr(mock_app)
|
||||||
|
|
||||||
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock) as mock_init:
|
with patch.object(
|
||||||
|
LocalStorageProvider, "initialize", new_callable=AsyncMock
|
||||||
|
) as mock_init:
|
||||||
await storage_mgr.initialize()
|
await storage_mgr.initialize()
|
||||||
mock_init.assert_called_once()
|
mock_init.assert_called_once()
|
||||||
|
|
||||||
@@ -101,8 +105,8 @@ class TestStorageProviderBase:
|
|||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
|
|
||||||
# Use LocalStorageProvider as concrete implementation
|
# Use LocalStorageProvider as concrete implementation
|
||||||
with patch('os.path.exists', return_value=True):
|
with patch("os.path.exists", return_value=True):
|
||||||
with patch('os.makedirs'):
|
with patch("os.makedirs"):
|
||||||
provider = LocalStorageProvider(mock_app)
|
provider = LocalStorageProvider(mock_app)
|
||||||
assert provider.ap == mock_app
|
assert provider.ap == mock_app
|
||||||
|
|
||||||
@@ -111,12 +115,12 @@ class TestStorageProviderBase:
|
|||||||
"""Provider base initialize should be callable and do nothing."""
|
"""Provider base initialize should be callable and do nothing."""
|
||||||
mock_app = Mock()
|
mock_app = Mock()
|
||||||
|
|
||||||
with patch('os.path.exists', return_value=True):
|
with patch("os.path.exists", return_value=True):
|
||||||
with patch('os.makedirs'):
|
with patch("os.makedirs"):
|
||||||
provider = LocalStorageProvider(mock_app)
|
provider = LocalStorageProvider(mock_app)
|
||||||
# Initialize should not raise
|
# Initialize should not raise
|
||||||
await provider.initialize()
|
await provider.initialize()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -8,7 +8,6 @@ Tests cover:
|
|||||||
- HTTP request success/failure scenarios
|
- HTTP request success/failure scenarios
|
||||||
- Source code bug: send_tasks should be instance variable
|
- Source code bug: send_tasks should be instance variable
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -39,7 +38,6 @@ class TestTelemetryManagerInit:
|
|||||||
manager = telemetry.TelemetryManager(mock_app)
|
manager = telemetry.TelemetryManager(mock_app)
|
||||||
assert manager.telemetry_config == {}
|
assert manager.telemetry_config == {}
|
||||||
|
|
||||||
|
|
||||||
class TestTelemetryManagerInitialize:
|
class TestTelemetryManagerInitialize:
|
||||||
"""Tests for initialize() method."""
|
"""Tests for initialize() method."""
|
||||||
|
|
||||||
@@ -220,7 +218,7 @@ class TestPayloadSanitization:
|
|||||||
|
|
||||||
# All null string fields should be empty strings
|
# All null string fields should be empty strings
|
||||||
for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']:
|
for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']:
|
||||||
assert result[field] == '', f'Field {field} should be empty string, got {result[field]}'
|
assert result[field] == '', f"Field {field} should be empty string, got {result[field]}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sanitize_string_fields_preserve_values(self):
|
async def test_sanitize_string_fields_preserve_values(self):
|
||||||
@@ -420,7 +418,9 @@ class TestHTTPScenarios:
|
|||||||
manager.telemetry_config = {'url': 'https://example.com'}
|
manager.telemetry_config = {'url': 'https://example.com'}
|
||||||
|
|
||||||
mock_response = Mock(
|
mock_response = Mock(
|
||||||
status_code=200, text='{"code": 0, "msg": "success"}', json=Mock(return_value={'code': 0, 'msg': 'success'})
|
status_code=200,
|
||||||
|
text='{"code": 0, "msg": "success"}',
|
||||||
|
json=Mock(return_value={'code': 0, 'msg': 'success'})
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -448,7 +448,9 @@ class TestHTTPScenarios:
|
|||||||
manager.telemetry_config = {'url': 'https://example.com'}
|
manager.telemetry_config = {'url': 'https://example.com'}
|
||||||
|
|
||||||
mock_response = Mock(
|
mock_response = Mock(
|
||||||
status_code=500, text='Internal Server Error', json=Mock(return_value={'code': 500, 'msg': 'error'})
|
status_code=500,
|
||||||
|
text='Internal Server Error',
|
||||||
|
json=Mock(return_value={'code': 500, 'msg': 'error'})
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -476,7 +478,7 @@ class TestHTTPScenarios:
|
|||||||
mock_response = Mock(
|
mock_response = Mock(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
text='{"code": 400, "msg": "Bad Request"}',
|
text='{"code": 400, "msg": "Bad Request"}',
|
||||||
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}),
|
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'})
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
@@ -491,7 +493,7 @@ class TestHTTPScenarios:
|
|||||||
assert mock_app.logger.warning.call_count >= 1
|
assert mock_app.logger.warning.call_count >= 1
|
||||||
# Check that one of the calls contains application error info
|
# Check that one of the calls contains application error info
|
||||||
all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list]
|
all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list]
|
||||||
assert any('400' in w for w in all_warnings), f'No warning contained error code 400: {all_warnings}'
|
assert any('400' in w for w in all_warnings), f"No warning contained error code 400: {all_warnings}"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_timeout_logs_warning(self):
|
async def test_send_timeout_logs_warning(self):
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ Tests cover:
|
|||||||
Note: Do NOT use 'from __future__ import annotations' because
|
Note: Do NOT use 'from __future__ import annotations' because
|
||||||
funcschema.py expects actual type objects, not string annotations.
|
funcschema.py expects actual type objects, not string annotations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
|||||||
@@ -20,53 +20,55 @@ class TestGetQQImageDownloadableUrl:
|
|||||||
|
|
||||||
def test_basic_url(self):
|
def test_basic_url(self):
|
||||||
"""Parse basic image URL."""
|
"""Parse basic image URL."""
|
||||||
url = 'http://example.com/image.jpg'
|
url = "http://example.com/image.jpg"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
assert result_url == 'http://example.com/image.jpg'
|
assert result_url == "http://example.com/image.jpg"
|
||||||
assert query == {}
|
assert query == {}
|
||||||
|
|
||||||
def test_url_with_query_params(self):
|
def test_url_with_query_params(self):
|
||||||
"""Parse URL with query parameters."""
|
"""Parse URL with query parameters."""
|
||||||
url = 'http://example.com/image.jpg?param1=value1¶m2=value2'
|
url = "http://example.com/image.jpg?param1=value1¶m2=value2"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
assert result_url == 'http://example.com/image.jpg'
|
assert result_url == "http://example.com/image.jpg"
|
||||||
assert query == {'param1': ['value1'], 'param2': ['value2']}
|
assert query == {"param1": ["value1"], "param2": ["value2"]}
|
||||||
|
|
||||||
def test_url_with_port(self):
|
def test_url_with_port(self):
|
||||||
"""Parse URL with port number."""
|
"""Parse URL with port number."""
|
||||||
url = 'http://example.com:8080/image.jpg'
|
url = "http://example.com:8080/image.jpg"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
assert result_url == 'http://example.com:8080/image.jpg'
|
assert result_url == "http://example.com:8080/image.jpg"
|
||||||
|
|
||||||
def test_url_with_path(self):
|
def test_url_with_path(self):
|
||||||
"""Parse URL with complex path."""
|
"""Parse URL with complex path."""
|
||||||
url = 'http://example.com/path/to/image.jpg'
|
url = "http://example.com/path/to/image.jpg"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
assert result_url == 'http://example.com/path/to/image.jpg'
|
assert result_url == "http://example.com/path/to/image.jpg"
|
||||||
|
|
||||||
def test_url_with_fragment(self):
|
def test_url_with_fragment(self):
|
||||||
"""Parse URL with fragment (fragment is not part of query)."""
|
"""Parse URL with fragment (fragment is not part of query)."""
|
||||||
url = 'http://example.com/image.jpg#fragment'
|
url = "http://example.com/image.jpg#fragment"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
# Fragment is not included in query string parsing
|
# Fragment is not included in query string parsing
|
||||||
assert 'http://example.com/image.jpg' in result_url
|
assert "http://example.com/image.jpg" in result_url
|
||||||
|
|
||||||
def test_https_url(self):
|
def test_https_url(self):
|
||||||
"""Parse HTTPS URL and preserve its scheme."""
|
"""Parse HTTPS URL and preserve its scheme."""
|
||||||
url = 'https://example.com/image.jpg'
|
url = "https://example.com/image.jpg"
|
||||||
result_url, query = get_qq_image_downloadable_url(url)
|
result_url, query = get_qq_image_downloadable_url(url)
|
||||||
|
|
||||||
assert result_url == 'https://example.com/image.jpg'
|
assert result_url == "https://example.com/image.jpg"
|
||||||
assert query == {}
|
assert query == {}
|
||||||
|
|
||||||
def test_preserves_qq_https_scheme_and_query(self):
|
def test_preserves_qq_https_scheme_and_query(self):
|
||||||
"""QQ image URLs keep HTTPS and query parameters."""
|
"""QQ image URLs keep HTTPS and query parameters."""
|
||||||
result_url, query = get_qq_image_downloadable_url('https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1')
|
result_url, query = get_qq_image_downloadable_url(
|
||||||
|
'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1'
|
||||||
|
)
|
||||||
|
|
||||||
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
|
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
|
||||||
assert query == {'term': ['2'], 'is_origin': ['1']}
|
assert query == {'term': ['2'], 'is_origin': ['1']}
|
||||||
@@ -86,50 +88,50 @@ class TestExtractB64AndFormat:
|
|||||||
async def test_jpeg_data_uri(self):
|
async def test_jpeg_data_uri(self):
|
||||||
"""Extract base64 and format from JPEG data URI."""
|
"""Extract base64 and format from JPEG data URI."""
|
||||||
# Create a simple base64 string
|
# Create a simple base64 string
|
||||||
original_data = b'test image data'
|
original_data = b"test image data"
|
||||||
b64_data = base64.b64encode(original_data).decode()
|
b64_data = base64.b64encode(original_data).decode()
|
||||||
data_uri = f'data:image/jpeg;base64,{b64_data}'
|
data_uri = f"data:image/jpeg;base64,{b64_data}"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
assert result_b64 == b64_data
|
assert result_b64 == b64_data
|
||||||
assert result_format == 'jpeg'
|
assert result_format == "jpeg"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_png_data_uri(self):
|
async def test_png_data_uri(self):
|
||||||
"""Extract base64 and format from PNG data URI."""
|
"""Extract base64 and format from PNG data URI."""
|
||||||
original_data = b'test png data'
|
original_data = b"test png data"
|
||||||
b64_data = base64.b64encode(original_data).decode()
|
b64_data = base64.b64encode(original_data).decode()
|
||||||
data_uri = f'data:image/png;base64,{b64_data}'
|
data_uri = f"data:image/png;base64,{b64_data}"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
assert result_b64 == b64_data
|
assert result_b64 == b64_data
|
||||||
assert result_format == 'png'
|
assert result_format == "png"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gif_data_uri(self):
|
async def test_gif_data_uri(self):
|
||||||
"""Extract base64 and format from GIF data URI."""
|
"""Extract base64 and format from GIF data URI."""
|
||||||
original_data = b'test gif data'
|
original_data = b"test gif data"
|
||||||
b64_data = base64.b64encode(original_data).decode()
|
b64_data = base64.b64encode(original_data).decode()
|
||||||
data_uri = f'data:image/gif;base64,{b64_data}'
|
data_uri = f"data:image/gif;base64,{b64_data}"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
assert result_b64 == b64_data
|
assert result_b64 == b64_data
|
||||||
assert result_format == 'gif'
|
assert result_format == "gif"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_webp_data_uri(self):
|
async def test_webp_data_uri(self):
|
||||||
"""Extract base64 and format from WebP data URI."""
|
"""Extract base64 and format from WebP data URI."""
|
||||||
original_data = b'test webp data'
|
original_data = b"test webp data"
|
||||||
b64_data = base64.b64encode(original_data).decode()
|
b64_data = base64.b64encode(original_data).decode()
|
||||||
data_uri = f'data:image/webp;base64,{b64_data}'
|
data_uri = f"data:image/webp;base64,{b64_data}"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
assert result_b64 == b64_data
|
assert result_b64 == b64_data
|
||||||
assert result_format == 'webp'
|
assert result_format == "webp"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_complex_base64(self):
|
async def test_complex_base64(self):
|
||||||
@@ -137,7 +139,7 @@ class TestExtractB64AndFormat:
|
|||||||
# Base64 can include + and / characters
|
# Base64 can include + and / characters
|
||||||
original_data = bytes(range(256)) # All byte values
|
original_data = bytes(range(256)) # All byte values
|
||||||
b64_data = base64.b64encode(original_data).decode()
|
b64_data = base64.b64encode(original_data).decode()
|
||||||
data_uri = f'data:image/png;base64,{b64_data}'
|
data_uri = f"data:image/png;base64,{b64_data}"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
@@ -148,9 +150,9 @@ class TestExtractB64AndFormat:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_base64(self):
|
async def test_empty_base64(self):
|
||||||
"""Handle empty base64 string."""
|
"""Handle empty base64 string."""
|
||||||
data_uri = 'data:image/png;base64,'
|
data_uri = "data:image/png;base64,"
|
||||||
|
|
||||||
result_b64, result_format = await extract_b64_and_format(data_uri)
|
result_b64, result_format = await extract_b64_and_format(data_uri)
|
||||||
|
|
||||||
assert result_b64 == ''
|
assert result_b64 == ""
|
||||||
assert result_format == 'png'
|
assert result_format == "png"
|
||||||
|
|||||||
@@ -23,52 +23,52 @@ class TestImportDir:
|
|||||||
|
|
||||||
def test_calls_importlib_for_each_python_file(self, tmp_path):
|
def test_calls_importlib_for_each_python_file(self, tmp_path):
|
||||||
"""Should call importlib.import_module for each .py file."""
|
"""Should call importlib.import_module for each .py file."""
|
||||||
module_dir = tmp_path / 'test_modules'
|
module_dir = tmp_path / "test_modules"
|
||||||
module_dir.mkdir()
|
module_dir.mkdir()
|
||||||
|
|
||||||
(module_dir / '__init__.py').write_text('')
|
(module_dir / "__init__.py").write_text("")
|
||||||
(module_dir / 'module_a.py').write_text("VALUE_A = 'a'\n")
|
(module_dir / "module_a.py").write_text("VALUE_A = 'a'\n")
|
||||||
(module_dir / 'module_b.py').write_text("VALUE_B = 'b'\n")
|
(module_dir / "module_b.py").write_text("VALUE_B = 'b'\n")
|
||||||
(module_dir / 'readme.txt').write_text('not a module')
|
(module_dir / "readme.txt").write_text("not a module")
|
||||||
|
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with patch.object(importlib, 'import_module') as mock_import:
|
with patch.object(importlib, "import_module") as mock_import:
|
||||||
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
|
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
|
||||||
# Should call import_module for each .py file (excluding __init__.py)
|
# Should call import_module for each .py file (excluding __init__.py)
|
||||||
assert mock_import.call_count == 2
|
assert mock_import.call_count == 2
|
||||||
|
|
||||||
def test_skips_init_py(self, tmp_path):
|
def test_skips_init_py(self, tmp_path):
|
||||||
"""Should skip __init__.py when importing."""
|
"""Should skip __init__.py when importing."""
|
||||||
module_dir = tmp_path / 'test_modules'
|
module_dir = tmp_path / "test_modules"
|
||||||
module_dir.mkdir()
|
module_dir.mkdir()
|
||||||
|
|
||||||
(module_dir / '__init__.py').write_text('')
|
(module_dir / "__init__.py").write_text("")
|
||||||
(module_dir / 'regular.py').write_text('VALUE = 1\n')
|
(module_dir / "regular.py").write_text("VALUE = 1\n")
|
||||||
|
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with patch.object(importlib, 'import_module') as mock_import:
|
with patch.object(importlib, "import_module") as mock_import:
|
||||||
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
|
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
|
||||||
# __init__.py should be skipped
|
# __init__.py should be skipped
|
||||||
mock_import.assert_called_once()
|
mock_import.assert_called_once()
|
||||||
# The call should not include __init__
|
# The call should not include __init__
|
||||||
call_args = mock_import.call_args[0][0]
|
call_args = mock_import.call_args[0][0]
|
||||||
assert '__init__' not in call_args
|
assert "__init__" not in call_args
|
||||||
|
|
||||||
def test_ignores_non_py_files(self, tmp_path):
|
def test_ignores_non_py_files(self, tmp_path):
|
||||||
"""Should ignore non-.py files."""
|
"""Should ignore non-.py files."""
|
||||||
module_dir = tmp_path / 'test_modules'
|
module_dir = tmp_path / "test_modules"
|
||||||
module_dir.mkdir()
|
module_dir.mkdir()
|
||||||
|
|
||||||
(module_dir / 'module.py').write_text('VALUE = 1\n')
|
(module_dir / "module.py").write_text("VALUE = 1\n")
|
||||||
(module_dir / 'readme.txt').write_text('text')
|
(module_dir / "readme.txt").write_text("text")
|
||||||
(module_dir / 'data.json').write_text('{}')
|
(module_dir / "data.json").write_text("{}")
|
||||||
|
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with patch.object(importlib, 'import_module') as mock_import:
|
with patch.object(importlib, "import_module") as mock_import:
|
||||||
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
|
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
|
||||||
# Only .py files should be imported
|
# Only .py files should be imported
|
||||||
assert mock_import.call_count == 1
|
assert mock_import.call_count == 1
|
||||||
|
|
||||||
@@ -79,14 +79,14 @@ class TestImportModulesInPkg:
|
|||||||
def test_imports_modules_from_package(self, tmp_path):
|
def test_imports_modules_from_package(self, tmp_path):
|
||||||
"""Should import all modules from a package object."""
|
"""Should import all modules from a package object."""
|
||||||
mock_pkg = MagicMock()
|
mock_pkg = MagicMock()
|
||||||
mock_pkg.__file__ = str(tmp_path / '__init__.py')
|
mock_pkg.__file__ = str(tmp_path / "__init__.py")
|
||||||
|
|
||||||
(tmp_path / '__init__.py').write_text('')
|
(tmp_path / "__init__.py").write_text("")
|
||||||
(tmp_path / 'mod1.py').write_text('MOD1 = 1\n')
|
(tmp_path / "mod1.py").write_text("MOD1 = 1\n")
|
||||||
|
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with patch.object(importutil, 'import_dir') as mock_import_dir:
|
with patch.object(importutil, "import_dir") as mock_import_dir:
|
||||||
importutil.import_modules_in_pkg(mock_pkg)
|
importutil.import_modules_in_pkg(mock_pkg)
|
||||||
mock_import_dir.assert_called_once()
|
mock_import_dir.assert_called_once()
|
||||||
call_path = mock_import_dir.call_args[0][0]
|
call_path = mock_import_dir.call_args[0][0]
|
||||||
@@ -101,11 +101,11 @@ class TestImportModulesInPkgs:
|
|||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
mock_pkg1 = MagicMock()
|
mock_pkg1 = MagicMock()
|
||||||
mock_pkg1.__file__ = '/path/to/pkg1/__init__.py'
|
mock_pkg1.__file__ = "/path/to/pkg1/__init__.py"
|
||||||
mock_pkg2 = MagicMock()
|
mock_pkg2 = MagicMock()
|
||||||
mock_pkg2.__file__ = '/path/to/pkg2/__init__.py'
|
mock_pkg2.__file__ = "/path/to/pkg2/__init__.py"
|
||||||
|
|
||||||
with patch.object(importutil, 'import_modules_in_pkg') as mock_import:
|
with patch.object(importutil, "import_modules_in_pkg") as mock_import:
|
||||||
importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2])
|
importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2])
|
||||||
assert mock_import.call_count == 2
|
assert mock_import.call_count == 2
|
||||||
|
|
||||||
@@ -116,18 +116,18 @@ class TestImportDotStyleDir:
|
|||||||
def test_converts_dot_notation_to_path(self, tmp_path):
|
def test_converts_dot_notation_to_path(self, tmp_path):
|
||||||
"""Should convert dot notation to path and import."""
|
"""Should convert dot notation to path and import."""
|
||||||
# Create structure matching the dot notation
|
# Create structure matching the dot notation
|
||||||
(tmp_path / 'my').mkdir()
|
(tmp_path / "my").mkdir()
|
||||||
(tmp_path / 'my' / 'pkg').mkdir()
|
(tmp_path / "my" / "pkg").mkdir()
|
||||||
(tmp_path / 'my' / 'pkg' / 'test').mkdir()
|
(tmp_path / "my" / "pkg" / "test").mkdir()
|
||||||
|
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with patch.object(importutil, 'import_dir') as mock_import_dir:
|
with patch.object(importutil, "import_dir") as mock_import_dir:
|
||||||
importutil.import_dot_style_dir('my.pkg.test')
|
importutil.import_dot_style_dir("my.pkg.test")
|
||||||
# The path should be converted using os.path.join
|
# The path should be converted using os.path.join
|
||||||
call_path = mock_import_dir.call_args[0][0]
|
call_path = mock_import_dir.call_args[0][0]
|
||||||
# Should contain the path components joined
|
# Should contain the path components joined
|
||||||
assert 'my' in call_path
|
assert "my" in call_path
|
||||||
|
|
||||||
|
|
||||||
class TestReadResourceFile:
|
class TestReadResourceFile:
|
||||||
@@ -137,16 +137,16 @@ class TestReadResourceFile:
|
|||||||
"""Should read content from a resource file."""
|
"""Should read content from a resource file."""
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
content = importutil.read_resource_file('templates/config.yaml')
|
content = importutil.read_resource_file("templates/config.yaml")
|
||||||
assert 'admins:' in content
|
assert "admins:" in content
|
||||||
assert 'edition: community' in content
|
assert "edition: community" in content
|
||||||
|
|
||||||
def test_raises_for_nonexistent_file(self):
|
def test_raises_for_nonexistent_file(self):
|
||||||
"""Should raise exception for non-existent resource file."""
|
"""Should raise exception for non-existent resource file."""
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with pytest.raises((FileNotFoundError, Exception)):
|
with pytest.raises((FileNotFoundError, Exception)):
|
||||||
importutil.read_resource_file('nonexistent/path/file.txt')
|
importutil.read_resource_file("nonexistent/path/file.txt")
|
||||||
|
|
||||||
|
|
||||||
class TestReadResourceFileBytes:
|
class TestReadResourceFileBytes:
|
||||||
@@ -156,16 +156,16 @@ class TestReadResourceFileBytes:
|
|||||||
"""Should read content as bytes from a resource file."""
|
"""Should read content as bytes from a resource file."""
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
content = importutil.read_resource_file_bytes('templates/config.yaml')
|
content = importutil.read_resource_file_bytes("templates/config.yaml")
|
||||||
assert b'admins:' in content
|
assert b"admins:" in content
|
||||||
assert b'edition: community' in content
|
assert b"edition: community" in content
|
||||||
|
|
||||||
def test_raises_for_nonexistent_file_bytes(self):
|
def test_raises_for_nonexistent_file_bytes(self):
|
||||||
"""Should raise exception for non-existent resource file."""
|
"""Should raise exception for non-existent resource file."""
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with pytest.raises((FileNotFoundError, Exception)):
|
with pytest.raises((FileNotFoundError, Exception)):
|
||||||
importutil.read_resource_file_bytes('nonexistent/path/file.txt')
|
importutil.read_resource_file_bytes("nonexistent/path/file.txt")
|
||||||
|
|
||||||
|
|
||||||
class TestListResourceFiles:
|
class TestListResourceFiles:
|
||||||
@@ -175,9 +175,9 @@ class TestListResourceFiles:
|
|||||||
"""Should list files in a resource directory."""
|
"""Should list files in a resource directory."""
|
||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
files = importutil.list_resource_files('templates')
|
files = importutil.list_resource_files("templates")
|
||||||
assert 'config.yaml' in files
|
assert "config.yaml" in files
|
||||||
assert 'default-pipeline-config.json' in files
|
assert "default-pipeline-config.json" in files
|
||||||
assert all(isinstance(file, str) for file in files)
|
assert all(isinstance(file, str) for file in files)
|
||||||
|
|
||||||
def test_raises_for_nonexistent_directory(self):
|
def test_raises_for_nonexistent_directory(self):
|
||||||
@@ -185,8 +185,8 @@ class TestListResourceFiles:
|
|||||||
from langbot.pkg.utils import importutil
|
from langbot.pkg.utils import importutil
|
||||||
|
|
||||||
with pytest.raises((FileNotFoundError, Exception)):
|
with pytest.raises((FileNotFoundError, Exception)):
|
||||||
importutil.list_resource_files('nonexistent_directory_xyz')
|
importutil.list_resource_files("nonexistent_directory_xyz")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ Tests cover:
|
|||||||
- Docker environment detection
|
- Docker environment detection
|
||||||
- WebSocket plugin runtime mode
|
- WebSocket plugin runtime mode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -87,4 +86,4 @@ class TestGetPlatform:
|
|||||||
assert platform_module.use_websocket_to_connect_plugin_runtime() is True
|
assert platform_module.use_websocket_to_connect_plugin_runtime() is True
|
||||||
|
|
||||||
# Restore
|
# Restore
|
||||||
platform_module.standalone_runtime = original
|
platform_module.standalone_runtime = original
|
||||||
@@ -60,12 +60,10 @@ class TestProxyManager:
|
|||||||
|
|
||||||
async def test_initialize_config_overrides_env(self):
|
async def test_initialize_config_overrides_env(self):
|
||||||
"""Config proxy overrides environment variables."""
|
"""Config proxy overrides environment variables."""
|
||||||
mock_app = self._create_mock_app(
|
mock_app = self._create_mock_app(proxy_config={
|
||||||
proxy_config={
|
'http': 'http://config-proxy:8080',
|
||||||
'http': 'http://config-proxy:8080',
|
'https': 'https://config-proxy:8443',
|
||||||
'https': 'https://config-proxy:8443',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}):
|
with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}):
|
||||||
pm = ProxyManager(mock_app)
|
pm = ProxyManager(mock_app)
|
||||||
@@ -76,12 +74,10 @@ class TestProxyManager:
|
|||||||
|
|
||||||
async def test_initialize_sets_env_variables(self):
|
async def test_initialize_sets_env_variables(self):
|
||||||
"""initialize sets proxy to environment variables."""
|
"""initialize sets proxy to environment variables."""
|
||||||
mock_app = self._create_mock_app(
|
mock_app = self._create_mock_app(proxy_config={
|
||||||
proxy_config={
|
'http': 'http://test-proxy:8080',
|
||||||
'http': 'http://test-proxy:8080',
|
'https': 'https://test-proxy:8443',
|
||||||
'https': 'https://test-proxy:8443',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
pm = ProxyManager(mock_app)
|
pm = ProxyManager(mock_app)
|
||||||
await pm.initialize()
|
await pm.initialize()
|
||||||
@@ -147,11 +143,9 @@ class TestProxyManager:
|
|||||||
|
|
||||||
async def test_initialize_http_only_config(self):
|
async def test_initialize_http_only_config(self):
|
||||||
"""initialize handles http-only config."""
|
"""initialize handles http-only config."""
|
||||||
mock_app = self._create_mock_app(
|
mock_app = self._create_mock_app(proxy_config={
|
||||||
proxy_config={
|
'http': 'http://http-only:8080',
|
||||||
'http': 'http://http-only:8080',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear any existing proxy env vars
|
# Clear any existing proxy env vars
|
||||||
env_backup = {}
|
env_backup = {}
|
||||||
|
|||||||
@@ -29,63 +29,63 @@ class TestGetRunnerCategory:
|
|||||||
|
|
||||||
def test_empty_url_returns_unknown(self):
|
def test_empty_url_returns_unknown(self):
|
||||||
"""Empty or None URL should return UNKNOWN."""
|
"""Empty or None URL should return UNKNOWN."""
|
||||||
assert get_runner_category('test', '') == RunnerCategory.UNKNOWN
|
assert get_runner_category("test", "") == RunnerCategory.UNKNOWN
|
||||||
assert get_runner_category('test', None) == RunnerCategory.UNKNOWN
|
assert get_runner_category("test", None) == RunnerCategory.UNKNOWN
|
||||||
|
|
||||||
def test_localhost_returns_local(self):
|
def test_localhost_returns_local(self):
|
||||||
"""localhost URL should be categorized as LOCAL."""
|
"""localhost URL should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://localhost:3000') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://localhost:3000") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'https://localhost') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "https://localhost") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_127_0_0_1_returns_local(self):
|
def test_127_0_0_1_returns_local(self):
|
||||||
"""127.0.0.1 URL should be categorized as LOCAL."""
|
"""127.0.0.1 URL should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://127.0.0.1:8080') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://127.0.0.1:8080") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'https://127.0.0.1') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "https://127.0.0.1") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_0_0_0_0_returns_local(self):
|
def test_0_0_0_0_returns_local(self):
|
||||||
"""0.0.0.0 URL should be categorized as LOCAL."""
|
"""0.0.0.0 URL should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://0.0.0.0:8080') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://0.0.0.0:8080") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_private_ip_192_168_returns_local(self):
|
def test_private_ip_192_168_returns_local(self):
|
||||||
"""192.168.x.x private IP should be categorized as LOCAL."""
|
"""192.168.x.x private IP should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://192.168.1.1:3000') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://192.168.1.1:3000") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'http://192.168.0.100') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://192.168.0.100") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_private_ip_10_returns_local(self):
|
def test_private_ip_10_returns_local(self):
|
||||||
"""10.x.x.x private IP should be categorized as LOCAL."""
|
"""10.x.x.x private IP should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://10.0.0.1:8080') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://10.0.0.1:8080") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'http://10.255.255.255') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://10.255.255.255") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_private_ip_172_16_31_returns_local(self):
|
def test_private_ip_172_16_31_returns_local(self):
|
||||||
"""172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL."""
|
"""172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL."""
|
||||||
assert get_runner_category('test', 'http://172.16.0.1:8080') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://172.16.0.1:8080") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'http://172.20.0.1') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://172.20.0.1") == RunnerCategory.LOCAL
|
||||||
assert get_runner_category('test', 'http://172.31.255.255') == RunnerCategory.LOCAL
|
assert get_runner_category("test", "http://172.31.255.255") == RunnerCategory.LOCAL
|
||||||
|
|
||||||
def test_n8n_cloud_returns_cloud(self):
|
def test_n8n_cloud_returns_cloud(self):
|
||||||
"""n8n.cloud domain should be categorized as CLOUD."""
|
"""n8n.cloud domain should be categorized as CLOUD."""
|
||||||
assert get_runner_category('test', 'https://myinstance.n8n.cloud') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://myinstance.n8n.cloud") == RunnerCategory.CLOUD
|
||||||
assert get_runner_category('test', 'https://test.n8n.io') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://test.n8n.io") == RunnerCategory.CLOUD
|
||||||
|
|
||||||
def test_dify_cloud_returns_cloud(self):
|
def test_dify_cloud_returns_cloud(self):
|
||||||
"""Dify cloud domains should be categorized as CLOUD."""
|
"""Dify cloud domains should be categorized as CLOUD."""
|
||||||
assert get_runner_category('test', 'https://api.dify.ai/v1') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://api.dify.ai/v1") == RunnerCategory.CLOUD
|
||||||
assert get_runner_category('test', 'https://cloud.dify.ai') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://cloud.dify.ai") == RunnerCategory.CLOUD
|
||||||
|
|
||||||
def test_coze_cloud_returns_cloud(self):
|
def test_coze_cloud_returns_cloud(self):
|
||||||
"""Coze domains should be categorized as CLOUD."""
|
"""Coze domains should be categorized as CLOUD."""
|
||||||
assert get_runner_category('test', 'https://api.coze.com') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://api.coze.com") == RunnerCategory.CLOUD
|
||||||
assert get_runner_category('test', 'https://api.coze.cn') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://api.coze.cn") == RunnerCategory.CLOUD
|
||||||
|
|
||||||
def test_langflow_cloud_returns_cloud(self):
|
def test_langflow_cloud_returns_cloud(self):
|
||||||
"""Langflow domains should be categorized as CLOUD."""
|
"""Langflow domains should be categorized as CLOUD."""
|
||||||
assert get_runner_category('test', 'https://cloud.langflow.ai') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://cloud.langflow.ai") == RunnerCategory.CLOUD
|
||||||
assert get_runner_category('test', 'https://test.langflow.org') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://test.langflow.org") == RunnerCategory.CLOUD
|
||||||
|
|
||||||
def test_other_url_returns_cloud(self):
|
def test_other_url_returns_cloud(self):
|
||||||
"""Other URLs should default to CLOUD category."""
|
"""Other URLs should default to CLOUD category."""
|
||||||
assert get_runner_category('test', 'https://example.com') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD
|
||||||
assert get_runner_category('test', 'https://myserver.example.org') == RunnerCategory.CLOUD
|
assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'runner_url',
|
'runner_url',
|
||||||
@@ -101,7 +101,7 @@ class TestGetRunnerCategory:
|
|||||||
)
|
)
|
||||||
def test_invalid_urls_return_unknown(self, runner_url):
|
def test_invalid_urls_return_unknown(self, runner_url):
|
||||||
"""Invalid or incomplete URLs should return UNKNOWN."""
|
"""Invalid or incomplete URLs should return UNKNOWN."""
|
||||||
assert get_runner_category('test', runner_url) == RunnerCategory.UNKNOWN
|
assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
|
||||||
|
|
||||||
def test_urlparse_exception_returns_unknown(self):
|
def test_urlparse_exception_returns_unknown(self):
|
||||||
"""Exception during URL parsing should return UNKNOWN."""
|
"""Exception during URL parsing should return UNKNOWN."""
|
||||||
@@ -109,15 +109,15 @@ class TestGetRunnerCategory:
|
|||||||
from langbot.pkg.utils import runner
|
from langbot.pkg.utils import runner
|
||||||
|
|
||||||
def mock_urlparse(url):
|
def mock_urlparse(url):
|
||||||
raise Exception('URL parsing failed')
|
raise Exception("URL parsing failed")
|
||||||
|
|
||||||
with patch('langbot.pkg.utils.runner.urlparse', side_effect=mock_urlparse):
|
with patch("langbot.pkg.utils.runner.urlparse", side_effect=mock_urlparse):
|
||||||
result = runner.get_runner_category('test', 'http://example.com')
|
result = runner.get_runner_category("test", "http://example.com")
|
||||||
assert result == RunnerCategory.UNKNOWN
|
assert result == RunnerCategory.UNKNOWN
|
||||||
|
|
||||||
def test_url_without_scheme_returns_unknown(self):
|
def test_url_without_scheme_returns_unknown(self):
|
||||||
"""URL without scheme should return UNKNOWN."""
|
"""URL without scheme should return UNKNOWN."""
|
||||||
assert get_runner_category('test', 'example.com') == RunnerCategory.UNKNOWN
|
assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'runner_url',
|
'runner_url',
|
||||||
@@ -146,21 +146,20 @@ class TestGetRunnerCategory:
|
|||||||
"""Domain names that only look like private IP prefixes should not be LOCAL."""
|
"""Domain names that only look like private IP prefixes should not be LOCAL."""
|
||||||
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
|
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
|
||||||
|
|
||||||
|
|
||||||
class TestIsCloudRunner:
|
class TestIsCloudRunner:
|
||||||
"""Test is_cloud_runner helper function."""
|
"""Test is_cloud_runner helper function."""
|
||||||
|
|
||||||
def test_cloud_runner_returns_true(self):
|
def test_cloud_runner_returns_true(self):
|
||||||
"""Cloud URL should return True."""
|
"""Cloud URL should return True."""
|
||||||
assert is_cloud_runner('test', 'https://api.dify.ai') is True
|
assert is_cloud_runner("test", "https://api.dify.ai") is True
|
||||||
|
|
||||||
def test_local_runner_returns_false(self):
|
def test_local_runner_returns_false(self):
|
||||||
"""Local URL should return False."""
|
"""Local URL should return False."""
|
||||||
assert is_cloud_runner('test', 'http://localhost:3000') is False
|
assert is_cloud_runner("test", "http://localhost:3000") is False
|
||||||
|
|
||||||
def test_unknown_returns_false(self):
|
def test_unknown_returns_false(self):
|
||||||
"""Unknown category should return False."""
|
"""Unknown category should return False."""
|
||||||
assert is_cloud_runner('test', None) is False
|
assert is_cloud_runner("test", None) is False
|
||||||
|
|
||||||
|
|
||||||
class TestIsLocalRunner:
|
class TestIsLocalRunner:
|
||||||
@@ -168,15 +167,15 @@ class TestIsLocalRunner:
|
|||||||
|
|
||||||
def test_local_runner_returns_true(self):
|
def test_local_runner_returns_true(self):
|
||||||
"""Local URL should return True."""
|
"""Local URL should return True."""
|
||||||
assert is_local_runner('test', 'http://localhost:3000') is True
|
assert is_local_runner("test", "http://localhost:3000") is True
|
||||||
|
|
||||||
def test_cloud_runner_returns_false(self):
|
def test_cloud_runner_returns_false(self):
|
||||||
"""Cloud URL should return False."""
|
"""Cloud URL should return False."""
|
||||||
assert is_local_runner('test', 'https://api.dify.ai') is False
|
assert is_local_runner("test", "https://api.dify.ai") is False
|
||||||
|
|
||||||
def test_unknown_returns_false(self):
|
def test_unknown_returns_false(self):
|
||||||
"""Unknown category should return False."""
|
"""Unknown category should return False."""
|
||||||
assert is_local_runner('test', None) is False
|
assert is_local_runner("test", None) is False
|
||||||
|
|
||||||
|
|
||||||
class TestGetRunnerInfo:
|
class TestGetRunnerInfo:
|
||||||
@@ -184,17 +183,17 @@ class TestGetRunnerInfo:
|
|||||||
|
|
||||||
def test_returns_dict_with_expected_keys(self):
|
def test_returns_dict_with_expected_keys(self):
|
||||||
"""Should return dict with name, url, and category keys."""
|
"""Should return dict with name, url, and category keys."""
|
||||||
info = get_runner_info('my-runner', 'http://localhost:3000')
|
info = get_runner_info("my-runner", "http://localhost:3000")
|
||||||
assert 'name' in info
|
assert "name" in info
|
||||||
assert 'url' in info
|
assert "url" in info
|
||||||
assert 'category' in info
|
assert "category" in info
|
||||||
|
|
||||||
def test_includes_correct_values(self):
|
def test_includes_correct_values(self):
|
||||||
"""Should include correct values in dict."""
|
"""Should include correct values in dict."""
|
||||||
info = get_runner_info('my-runner', 'http://localhost:3000')
|
info = get_runner_info("my-runner", "http://localhost:3000")
|
||||||
assert info['name'] == 'my-runner'
|
assert info["name"] == "my-runner"
|
||||||
assert info['url'] == 'http://localhost:3000'
|
assert info["url"] == "http://localhost:3000"
|
||||||
assert info['category'] == RunnerCategory.LOCAL
|
assert info["category"] == RunnerCategory.LOCAL
|
||||||
|
|
||||||
|
|
||||||
class TestExtractRunnerUrl:
|
class TestExtractRunnerUrl:
|
||||||
@@ -204,58 +203,74 @@ class TestExtractRunnerUrl:
|
|||||||
"""Should extract base-url from dify-service-api config."""
|
"""Should extract base-url from dify-service-api config."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
|
pipeline_config = {
|
||||||
url = extract_runner_url('dify-service-api', runner, pipeline_config)
|
"ai": {
|
||||||
assert url == 'https://api.dify.ai'
|
"dify-service-api": {"base-url": "https://api.dify.ai"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
url = extract_runner_url("dify-service-api", runner, pipeline_config)
|
||||||
|
assert url == "https://api.dify.ai"
|
||||||
|
|
||||||
def test_n8n_service_api_extracts_url(self):
|
def test_n8n_service_api_extracts_url(self):
|
||||||
"""Should extract webhook-url from n8n-service-api config."""
|
"""Should extract webhook-url from n8n-service-api config."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {'ai': {'n8n-service-api': {'webhook-url': 'https://my.n8n.cloud/webhook'}}}
|
pipeline_config = {
|
||||||
url = extract_runner_url('n8n-service-api', runner, pipeline_config)
|
"ai": {
|
||||||
assert url == 'https://my.n8n.cloud/webhook'
|
"n8n-service-api": {"webhook-url": "https://my.n8n.cloud/webhook"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
url = extract_runner_url("n8n-service-api", runner, pipeline_config)
|
||||||
|
assert url == "https://my.n8n.cloud/webhook"
|
||||||
|
|
||||||
def test_coze_api_extracts_url(self):
|
def test_coze_api_extracts_url(self):
|
||||||
"""Should extract api-base from coze-api config."""
|
"""Should extract api-base from coze-api config."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {'ai': {'coze-api': {'api-base': 'https://api.coze.com'}}}
|
pipeline_config = {
|
||||||
url = extract_runner_url('coze-api', runner, pipeline_config)
|
"ai": {
|
||||||
assert url == 'https://api.coze.com'
|
"coze-api": {"api-base": "https://api.coze.com"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
url = extract_runner_url("coze-api", runner, pipeline_config)
|
||||||
|
assert url == "https://api.coze.com"
|
||||||
|
|
||||||
def test_langflow_api_extracts_url(self):
|
def test_langflow_api_extracts_url(self):
|
||||||
"""Should extract base-url from langflow-api config."""
|
"""Should extract base-url from langflow-api config."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {'ai': {'langflow-api': {'base-url': 'https://cloud.langflow.ai'}}}
|
pipeline_config = {
|
||||||
url = extract_runner_url('langflow-api', runner, pipeline_config)
|
"ai": {
|
||||||
assert url == 'https://cloud.langflow.ai'
|
"langflow-api": {"base-url": "https://cloud.langflow.ai"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
url = extract_runner_url("langflow-api", runner, pipeline_config)
|
||||||
|
assert url == "https://cloud.langflow.ai"
|
||||||
|
|
||||||
def test_unknown_runner_returns_none(self):
|
def test_unknown_runner_returns_none(self):
|
||||||
"""Unknown runner name should return None."""
|
"""Unknown runner name should return None."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {}
|
pipeline_config = {}
|
||||||
url = extract_runner_url('unknown-runner', runner, pipeline_config)
|
url = extract_runner_url("unknown-runner", runner, pipeline_config)
|
||||||
assert url is None
|
assert url is None
|
||||||
|
|
||||||
def test_none_runner_returns_none(self):
|
def test_none_runner_returns_none(self):
|
||||||
"""None runner should return None."""
|
"""None runner should return None."""
|
||||||
url = extract_runner_url('test', None, {})
|
url = extract_runner_url("test", None, {})
|
||||||
assert url is None
|
assert url is None
|
||||||
|
|
||||||
def test_runner_without_pipeline_config_returns_none(self):
|
def test_runner_without_pipeline_config_returns_none(self):
|
||||||
"""Runner without pipeline_config attribute should return None."""
|
"""Runner without pipeline_config attribute should return None."""
|
||||||
runner = Mock(spec=[]) # Empty spec means no attributes
|
runner = Mock(spec=[]) # Empty spec means no attributes
|
||||||
url = extract_runner_url('test', runner, {})
|
url = extract_runner_url("test", runner, {})
|
||||||
assert url is None
|
assert url is None
|
||||||
|
|
||||||
def test_none_pipeline_config_returns_none(self):
|
def test_none_pipeline_config_returns_none(self):
|
||||||
"""None pipeline_config should return None."""
|
"""None pipeline_config should return None."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
url = extract_runner_url('dify-service-api', runner, None)
|
url = extract_runner_url("dify-service-api", runner, None)
|
||||||
assert url is None
|
assert url is None
|
||||||
|
|
||||||
def test_missing_ai_config_returns_none(self):
|
def test_missing_ai_config_returns_none(self):
|
||||||
@@ -263,7 +278,7 @@ class TestExtractRunnerUrl:
|
|||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {}
|
pipeline_config = {}
|
||||||
url = extract_runner_url('dify-service-api', runner, pipeline_config)
|
url = extract_runner_url("dify-service-api", runner, pipeline_config)
|
||||||
assert url is None
|
assert url is None
|
||||||
|
|
||||||
|
|
||||||
@@ -274,15 +289,19 @@ class TestGetRunnerCategoryFromRunner:
|
|||||||
"""Should extract URL and return correct category."""
|
"""Should extract URL and return correct category."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
|
pipeline_config = {
|
||||||
category = get_runner_category_from_runner('dify-service-api', runner, pipeline_config)
|
"ai": {
|
||||||
|
"dify-service-api": {"base-url": "https://api.dify.ai"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
category = get_runner_category_from_runner("dify-service-api", runner, pipeline_config)
|
||||||
assert category == RunnerCategory.CLOUD
|
assert category == RunnerCategory.CLOUD
|
||||||
|
|
||||||
def test_returns_unknown_for_missing_url(self):
|
def test_returns_unknown_for_missing_url(self):
|
||||||
"""Should return UNKNOWN when URL cannot be extracted."""
|
"""Should return UNKNOWN when URL cannot be extracted."""
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.pipeline_config = {}
|
runner.pipeline_config = {}
|
||||||
category = get_runner_category_from_runner('unknown', runner, {})
|
category = get_runner_category_from_runner("unknown", runner, {})
|
||||||
assert category == RunnerCategory.UNKNOWN
|
assert category == RunnerCategory.UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
@@ -291,9 +310,9 @@ class TestConstants:
|
|||||||
|
|
||||||
def test_runner_category_constants(self):
|
def test_runner_category_constants(self):
|
||||||
"""RunnerCategory should have LOCAL, CLOUD, UNKNOWN."""
|
"""RunnerCategory should have LOCAL, CLOUD, UNKNOWN."""
|
||||||
assert RunnerCategory.LOCAL == 'local'
|
assert RunnerCategory.LOCAL == "local"
|
||||||
assert RunnerCategory.CLOUD == 'cloud'
|
assert RunnerCategory.CLOUD == "cloud"
|
||||||
assert RunnerCategory.UNKNOWN == 'unknown'
|
assert RunnerCategory.UNKNOWN == "unknown"
|
||||||
|
|
||||||
def test_cloud_domains_not_empty(self):
|
def test_cloud_domains_not_empty(self):
|
||||||
"""CLOUD_DOMAINS should not be empty."""
|
"""CLOUD_DOMAINS should not be empty."""
|
||||||
@@ -304,5 +323,5 @@ class TestConstants:
|
|||||||
assert len(LOCAL_PATTERNS) > 0
|
assert len(LOCAL_PATTERNS) > 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
@@ -68,7 +68,11 @@ class TestNormalizeFilter:
|
|||||||
|
|
||||||
def test_normalize_filter_multiple_conditions(self):
|
def test_normalize_filter_multiple_conditions(self):
|
||||||
"""Multiple top-level keys are AND-ed (returned as multiple triples)."""
|
"""Multiple top-level keys are AND-ed (returned as multiple triples)."""
|
||||||
result = normalize_filter({'file_id': 'abc', 'status': {'$ne': 'deleted'}, 'created_at': {'$gte': 1700000000}})
|
result = normalize_filter({
|
||||||
|
'file_id': 'abc',
|
||||||
|
'status': {'$ne': 'deleted'},
|
||||||
|
'created_at': {'$gte': 1700000000}
|
||||||
|
})
|
||||||
|
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
# Order should match dict iteration order
|
# Order should match dict iteration order
|
||||||
@@ -145,7 +149,11 @@ class TestStripUnsupportedFields:
|
|||||||
('file_id', '$eq', 'def'),
|
('file_id', '$eq', 'def'),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = strip_unsupported_fields(triples, {'file_id', 'chunk_uuid'}, field_aliases={'uuid': 'chunk_uuid'})
|
result = strip_unsupported_fields(
|
||||||
|
triples,
|
||||||
|
{'file_id', 'chunk_uuid'},
|
||||||
|
field_aliases={'uuid': 'chunk_uuid'}
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
# 'uuid' should be resolved to 'chunk_uuid'
|
# 'uuid' should be resolved to 'chunk_uuid'
|
||||||
@@ -161,7 +169,7 @@ class TestStripUnsupportedFields:
|
|||||||
result = strip_unsupported_fields(
|
result = strip_unsupported_fields(
|
||||||
triples,
|
triples,
|
||||||
{'file_id'}, # chunk_uuid not supported
|
{'file_id'}, # chunk_uuid not supported
|
||||||
field_aliases={'uuid': 'chunk_uuid'},
|
field_aliases={'uuid': 'chunk_uuid'}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == []
|
assert result == []
|
||||||
@@ -199,5 +207,4 @@ class TestSupportedOpsConstant:
|
|||||||
def test_supported_ops_is_frozenset(self):
|
def test_supported_ops_is_frozenset(self):
|
||||||
"""SUPPORTED_OPS is a frozenset for immutability."""
|
"""SUPPORTED_OPS is a frozenset for immutability."""
|
||||||
from collections.abc import Set
|
from collections.abc import Set
|
||||||
|
assert isinstance(SUPPORTED_OPS, Set)
|
||||||
assert isinstance(SUPPORTED_OPS, Set)
|
|
||||||
@@ -55,7 +55,6 @@ class TestVectorDBManagerInitialization:
|
|||||||
|
|
||||||
# Run initialize synchronously for test
|
# Run initialize synchronously for test
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
# Chroma should be instantiated
|
# Chroma should be instantiated
|
||||||
@@ -77,7 +76,6 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_chroma_class.assert_called_once_with(mock_app)
|
mock_chroma_class.assert_called_once_with(mock_app)
|
||||||
@@ -98,7 +96,6 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_qdrant_class.assert_called_once_with(mock_app)
|
mock_qdrant_class.assert_called_once_with(mock_app)
|
||||||
@@ -118,7 +115,6 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_seekdb_class.assert_called_once_with(mock_app)
|
mock_seekdb_class.assert_called_once_with(mock_app)
|
||||||
@@ -127,7 +123,11 @@ class TestVectorDBManagerInitialization:
|
|||||||
"""Milvus config with custom URI."""
|
"""Milvus config with custom URI."""
|
||||||
vdb_config = {
|
vdb_config = {
|
||||||
'use': 'milvus',
|
'use': 'milvus',
|
||||||
'milvus': {'uri': 'http://localhost:19530', 'token': 'root:Milvus', 'db_name': 'langbot_db'},
|
'milvus': {
|
||||||
|
'uri': 'http://localhost:19530',
|
||||||
|
'token': 'root:Milvus',
|
||||||
|
'db_name': 'langbot_db'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
mock_app = self._create_mock_app(vdb_config)
|
mock_app = self._create_mock_app(vdb_config)
|
||||||
|
|
||||||
@@ -141,11 +141,13 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_milvus_class.assert_called_once_with(
|
mock_milvus_class.assert_called_once_with(
|
||||||
mock_app, uri='http://localhost:19530', token='root:Milvus', db_name='langbot_db'
|
mock_app,
|
||||||
|
uri='http://localhost:19530',
|
||||||
|
token='root:Milvus',
|
||||||
|
db_name='langbot_db'
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initialize_milvus_backend_defaults(self):
|
def test_initialize_milvus_backend_defaults(self):
|
||||||
@@ -163,45 +165,23 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
# Should use default values
|
# Should use default values
|
||||||
mock_milvus_class.assert_called_once_with(mock_app, uri='./data/milvus.db', token=None, db_name='default')
|
mock_milvus_class.assert_called_once_with(
|
||||||
|
mock_app,
|
||||||
|
uri='./data/milvus.db',
|
||||||
|
token=None,
|
||||||
|
db_name='default'
|
||||||
|
)
|
||||||
|
|
||||||
def test_initialize_pgvector_with_connection_string(self):
|
def test_initialize_pgvector_with_connection_string(self):
|
||||||
"""pgvector with connection string."""
|
"""pgvector with connection string."""
|
||||||
vdb_config = {'use': 'pgvector', 'pgvector': {'connection_string': 'postgresql://user:pass@host:5432/langbot'}}
|
|
||||||
mock_app = self._create_mock_app(vdb_config)
|
|
||||||
|
|
||||||
mocks = self._make_vector_import_mocks()
|
|
||||||
mock_pgvector_class = MagicMock()
|
|
||||||
mocks['langbot.pkg.vector.vdbs.pgvector_db'].PgVectorDatabase = mock_pgvector_class
|
|
||||||
|
|
||||||
with isolated_sys_modules(mocks):
|
|
||||||
from langbot.pkg.vector.mgr import VectorDBManager
|
|
||||||
|
|
||||||
mgr = VectorDBManager(mock_app)
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
|
||||||
|
|
||||||
mock_pgvector_class.assert_called_once_with(
|
|
||||||
mock_app, connection_string='postgresql://user:pass@host:5432/langbot'
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_initialize_pgvector_with_individual_params(self):
|
|
||||||
"""pgvector with individual connection parameters."""
|
|
||||||
vdb_config = {
|
vdb_config = {
|
||||||
'use': 'pgvector',
|
'use': 'pgvector',
|
||||||
'pgvector': {
|
'pgvector': {
|
||||||
'host': 'db.example.com',
|
'connection_string': 'postgresql://user:pass@host:5432/langbot'
|
||||||
'port': 5433,
|
}
|
||||||
'database': 'vectordb',
|
|
||||||
'user': 'admin',
|
|
||||||
'password': 'secret',
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
mock_app = self._create_mock_app(vdb_config)
|
mock_app = self._create_mock_app(vdb_config)
|
||||||
|
|
||||||
@@ -215,11 +195,46 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_pgvector_class.assert_called_once_with(
|
mock_pgvector_class.assert_called_once_with(
|
||||||
mock_app, host='db.example.com', port=5433, database='vectordb', user='admin', password='secret'
|
mock_app,
|
||||||
|
connection_string='postgresql://user:pass@host:5432/langbot'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialize_pgvector_with_individual_params(self):
|
||||||
|
"""pgvector with individual connection parameters."""
|
||||||
|
vdb_config = {
|
||||||
|
'use': 'pgvector',
|
||||||
|
'pgvector': {
|
||||||
|
'host': 'db.example.com',
|
||||||
|
'port': 5433,
|
||||||
|
'database': 'vectordb',
|
||||||
|
'user': 'admin',
|
||||||
|
'password': 'secret'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_app = self._create_mock_app(vdb_config)
|
||||||
|
|
||||||
|
mocks = self._make_vector_import_mocks()
|
||||||
|
mock_pgvector_class = MagicMock()
|
||||||
|
mocks['langbot.pkg.vector.vdbs.pgvector_db'].PgVectorDatabase = mock_pgvector_class
|
||||||
|
|
||||||
|
with isolated_sys_modules(mocks):
|
||||||
|
from langbot.pkg.vector.mgr import VectorDBManager
|
||||||
|
|
||||||
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
|
mock_pgvector_class.assert_called_once_with(
|
||||||
|
mock_app,
|
||||||
|
host='db.example.com',
|
||||||
|
port=5433,
|
||||||
|
database='vectordb',
|
||||||
|
user='admin',
|
||||||
|
password='secret'
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initialize_pgvector_defaults(self):
|
def test_initialize_pgvector_defaults(self):
|
||||||
@@ -237,11 +252,15 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_pgvector_class.assert_called_once_with(
|
mock_pgvector_class.assert_called_once_with(
|
||||||
mock_app, host='localhost', port=5432, database='langbot', user='postgres', password='postgres'
|
mock_app,
|
||||||
|
host='localhost',
|
||||||
|
port=5432,
|
||||||
|
database='langbot',
|
||||||
|
user='postgres',
|
||||||
|
password='postgres'
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initialize_unknown_backend_defaults_to_chroma(self):
|
def test_initialize_unknown_backend_defaults_to_chroma(self):
|
||||||
@@ -259,7 +278,6 @@ class TestVectorDBManagerInitialization:
|
|||||||
mgr = VectorDBManager(mock_app)
|
mgr = VectorDBManager(mock_app)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
asyncio.get_event_loop().run_until_complete(mgr.initialize())
|
||||||
|
|
||||||
mock_chroma_class.assert_called_once_with(mock_app)
|
mock_chroma_class.assert_called_once_with(mock_app)
|
||||||
@@ -317,4 +335,4 @@ class TestVectorDBManagerProxies:
|
|||||||
mgr.vector_db = mock_vector_db
|
mgr.vector_db = mock_vector_db
|
||||||
|
|
||||||
result = mgr.get_supported_search_types()
|
result = mgr.get_supported_search_types()
|
||||||
assert result == ['vector', 'full_text']
|
assert result == ['vector', 'full_text']
|
||||||
@@ -39,7 +39,6 @@ class TestVectorDatabaseAbstractMethods:
|
|||||||
|
|
||||||
def test_abstract_methods_required(self):
|
def test_abstract_methods_required(self):
|
||||||
"""Subclass must implement all abstract methods."""
|
"""Subclass must implement all abstract methods."""
|
||||||
|
|
||||||
class IncompleteVectorDB(VectorDatabase):
|
class IncompleteVectorDB(VectorDatabase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -48,21 +47,11 @@ class TestVectorDatabaseAbstractMethods:
|
|||||||
|
|
||||||
def test_supported_search_types_default(self):
|
def test_supported_search_types_default(self):
|
||||||
"""Default supported_search_types returns [VECTOR]."""
|
"""Default supported_search_types returns [VECTOR]."""
|
||||||
|
|
||||||
class MinimalVectorDB(VectorDatabase):
|
class MinimalVectorDB(VectorDatabase):
|
||||||
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def search(
|
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
|
||||||
self,
|
|
||||||
collection,
|
|
||||||
query_embedding,
|
|
||||||
k=5,
|
|
||||||
search_type='vector',
|
|
||||||
query_text='',
|
|
||||||
filter=None,
|
|
||||||
vector_weight=None,
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def delete_by_file_id(self, collection, file_id):
|
async def delete_by_file_id(self, collection, file_id):
|
||||||
@@ -82,21 +71,11 @@ class TestVectorDatabaseAbstractMethods:
|
|||||||
|
|
||||||
def test_list_by_filter_default_implementation(self):
|
def test_list_by_filter_default_implementation(self):
|
||||||
"""list_by_filter has default implementation returning empty."""
|
"""list_by_filter has default implementation returning empty."""
|
||||||
|
|
||||||
class MinimalVectorDB(VectorDatabase):
|
class MinimalVectorDB(VectorDatabase):
|
||||||
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def search(
|
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
|
||||||
self,
|
|
||||||
collection,
|
|
||||||
query_embedding,
|
|
||||||
k=5,
|
|
||||||
search_type='vector',
|
|
||||||
query_text='',
|
|
||||||
filter=None,
|
|
||||||
vector_weight=None,
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def delete_by_file_id(self, collection, file_id):
|
async def delete_by_file_id(self, collection, file_id):
|
||||||
@@ -114,8 +93,9 @@ class TestVectorDatabaseAbstractMethods:
|
|||||||
db = MinimalVectorDB()
|
db = MinimalVectorDB()
|
||||||
# list_by_filter should return empty list and -1 for total
|
# list_by_filter should return empty list and -1 for total
|
||||||
import asyncio
|
import asyncio
|
||||||
|
result = asyncio.get_event_loop().run_until_complete(
|
||||||
result = asyncio.get_event_loop().run_until_complete(db.list_by_filter('test_collection'))
|
db.list_by_filter('test_collection')
|
||||||
|
)
|
||||||
assert result == ([], -1)
|
assert result == ([], -1)
|
||||||
|
|
||||||
|
|
||||||
@@ -125,17 +105,14 @@ class TestVectorDatabaseInterface:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_vector_db(self):
|
def mock_vector_db(self):
|
||||||
"""Create a minimal mock VectorDatabase for testing."""
|
"""Create a minimal mock VectorDatabase for testing."""
|
||||||
|
|
||||||
class MockVectorDB(VectorDatabase):
|
class MockVectorDB(VectorDatabase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.add_embeddings = AsyncMock()
|
self.add_embeddings = AsyncMock()
|
||||||
self.search = AsyncMock(
|
self.search = AsyncMock(return_value={
|
||||||
return_value={
|
'ids': [['id1', 'id2']],
|
||||||
'ids': [['id1', 'id2']],
|
'distances': [[0.1, 0.2]],
|
||||||
'distances': [[0.1, 0.2]],
|
'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]]
|
||||||
'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]],
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
self.delete_by_file_id = AsyncMock()
|
self.delete_by_file_id = AsyncMock()
|
||||||
self.delete_by_filter = AsyncMock(return_value=5)
|
self.delete_by_filter = AsyncMock(return_value=5)
|
||||||
self.get_or_create_collection = AsyncMock()
|
self.get_or_create_collection = AsyncMock()
|
||||||
@@ -144,16 +121,7 @@ class TestVectorDatabaseInterface:
|
|||||||
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def search(
|
async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None):
|
||||||
self,
|
|
||||||
collection,
|
|
||||||
query_embedding,
|
|
||||||
k=5,
|
|
||||||
search_type='vector',
|
|
||||||
query_text='',
|
|
||||||
filter=None,
|
|
||||||
vector_weight=None,
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def delete_by_file_id(self, collection, file_id):
|
async def delete_by_file_id(self, collection, file_id):
|
||||||
@@ -178,7 +146,7 @@ class TestVectorDatabaseInterface:
|
|||||||
ids=['id1', 'id2'],
|
ids=['id1', 'id2'],
|
||||||
embeddings_list=[[0.1, 0.2], [0.3, 0.4]],
|
embeddings_list=[[0.1, 0.2], [0.3, 0.4]],
|
||||||
metadatas=[{'a': 1}, {'b': 2}],
|
metadatas=[{'a': 1}, {'b': 2}],
|
||||||
documents=['doc1', 'doc2'],
|
documents=['doc1', 'doc2']
|
||||||
)
|
)
|
||||||
mock_vector_db.add_embeddings.assert_called_once()
|
mock_vector_db.add_embeddings.assert_called_once()
|
||||||
|
|
||||||
@@ -194,7 +162,7 @@ class TestVectorDatabaseInterface:
|
|||||||
search_type='hybrid',
|
search_type='hybrid',
|
||||||
query_text='search text',
|
query_text='search text',
|
||||||
filter={'file_id': 'abc'},
|
filter={'file_id': 'abc'},
|
||||||
vector_weight=0.7,
|
vector_weight=0.7
|
||||||
)
|
)
|
||||||
mock_vector_db.search.assert_called_once()
|
mock_vector_db.search.assert_called_once()
|
||||||
|
|
||||||
@@ -202,4 +170,4 @@ class TestVectorDatabaseInterface:
|
|||||||
async def test_delete_by_filter_returns_int(self, mock_vector_db):
|
async def test_delete_by_filter_returns_int(self, mock_vector_db):
|
||||||
"""delete_by_filter returns int count."""
|
"""delete_by_filter returns int count."""
|
||||||
result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'})
|
result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'})
|
||||||
assert isinstance(result, int)
|
assert isinstance(result, int)
|
||||||
@@ -5,7 +5,6 @@ Tests cover:
|
|||||||
- _build_milvus_expr: Milvus boolean expression string conversion
|
- _build_milvus_expr: Milvus boolean expression string conversion
|
||||||
- _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion
|
- _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@@ -123,13 +122,11 @@ class TestQdrantFilterConversion:
|
|||||||
"""Multiple conditions are combined in must/must_not."""
|
"""Multiple conditions are combined in must/must_not."""
|
||||||
qdrant_module = get_qdrant_module()
|
qdrant_module = get_qdrant_module()
|
||||||
|
|
||||||
result = qdrant_module._build_qdrant_filter(
|
result = qdrant_module._build_qdrant_filter({
|
||||||
{
|
'file_id': 'abc',
|
||||||
'file_id': 'abc',
|
'status': {'$ne': 'deleted'},
|
||||||
'status': {'$ne': 'deleted'},
|
'created_at': {'$gte': 100},
|
||||||
'created_at': {'$gte': 100},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result.must) == 2 # file_id eq + created_at gte
|
assert len(result.must) == 2 # file_id eq + created_at gte
|
||||||
assert len(result.must_not) == 1 # status ne
|
assert len(result.must_not) == 1 # status ne
|
||||||
@@ -201,12 +198,10 @@ class TestMilvusFilterConversion:
|
|||||||
"""Multiple conditions are joined with 'and'."""
|
"""Multiple conditions are joined with 'and'."""
|
||||||
milvus_module = get_milvus_module()
|
milvus_module = get_milvus_module()
|
||||||
|
|
||||||
result = milvus_module._build_milvus_expr(
|
result = milvus_module._build_milvus_expr({
|
||||||
{
|
'file_id': 'abc',
|
||||||
'file_id': 'abc',
|
'chunk_uuid': {'$ne': 'def'},
|
||||||
'chunk_uuid': {'$ne': 'def'},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
assert 'and' in result
|
assert 'and' in result
|
||||||
assert 'file_id == "abc"' in result
|
assert 'file_id == "abc"' in result
|
||||||
assert 'chunk_uuid != "def"' in result
|
assert 'chunk_uuid != "def"' in result
|
||||||
@@ -277,7 +272,6 @@ class TestPgVectorFilterConversion:
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
# Verify it's a SQLAlchemy BinaryExpression
|
# Verify it's a SQLAlchemy BinaryExpression
|
||||||
from sqlalchemy.sql.expression import BinaryExpression
|
from sqlalchemy.sql.expression import BinaryExpression
|
||||||
|
|
||||||
assert isinstance(result[0], BinaryExpression)
|
assert isinstance(result[0], BinaryExpression)
|
||||||
|
|
||||||
def test_ne_operator_creates_inequality_condition(self):
|
def test_ne_operator_creates_inequality_condition(self):
|
||||||
@@ -327,12 +321,10 @@ class TestPgVectorFilterConversion:
|
|||||||
"""Multiple conditions return list of conditions."""
|
"""Multiple conditions return list of conditions."""
|
||||||
pgvector_module = get_pgvector_module()
|
pgvector_module = get_pgvector_module()
|
||||||
|
|
||||||
result = pgvector_module._build_pg_conditions(
|
result = pgvector_module._build_pg_conditions({
|
||||||
{
|
'file_id': 'abc',
|
||||||
'file_id': 'abc',
|
'chunk_uuid': {'$ne': 'def'},
|
||||||
'chunk_uuid': {'$ne': 'def'},
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
@@ -357,13 +349,11 @@ class TestPgVectorFilterConversion:
|
|||||||
"""Only supported fields (text, file_id, chunk_uuid) are kept."""
|
"""Only supported fields (text, file_id, chunk_uuid) are kept."""
|
||||||
pgvector_module = get_pgvector_module()
|
pgvector_module = get_pgvector_module()
|
||||||
|
|
||||||
result = pgvector_module._build_pg_conditions(
|
result = pgvector_module._build_pg_conditions({
|
||||||
{
|
'text': {'$ne': ''},
|
||||||
'text': {'$ne': ''},
|
'file_id': 'abc',
|
||||||
'file_id': 'abc',
|
'chunk_uuid': {'$in': ['x', 'y']},
|
||||||
'chunk_uuid': {'$in': ['x', 'y']},
|
'unsupported': 'value',
|
||||||
'unsupported': 'value',
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 3 # Only supported fields
|
assert len(result) == 3 # Only supported fields
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
"""
|
"""
|
||||||
Test utilities package.
|
Test utilities package.
|
||||||
"""
|
"""
|
||||||
@@ -26,7 +26,6 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
class MockLifecycleControlScope(enum.Enum):
|
class MockLifecycleControlScope(enum.Enum):
|
||||||
"""Mock enum for breaking circular import in core.entities."""
|
"""Mock enum for breaking circular import in core.entities."""
|
||||||
|
|
||||||
APPLICATION = 'application'
|
APPLICATION = 'application'
|
||||||
PLATFORM = 'platform'
|
PLATFORM = 'platform'
|
||||||
PLUGIN = 'plugin'
|
PLUGIN = 'plugin'
|
||||||
@@ -191,4 +190,4 @@ def get_handler_modules_to_clear(handler_name: str) -> list[str]:
|
|||||||
'langbot.pkg.pipeline.process.handler',
|
'langbot.pkg.pipeline.process.handler',
|
||||||
'langbot.pkg.pipeline.process.handlers',
|
'langbot.pkg.pipeline.process.handlers',
|
||||||
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
|
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
|
||||||
]
|
]
|
||||||
2
web/.gitignore
vendored
2
web/.gitignore
vendored
@@ -12,8 +12,6 @@
|
|||||||
|
|
||||||
# testing
|
# testing
|
||||||
/coverage
|
/coverage
|
||||||
/playwright-report
|
|
||||||
/test-results
|
|
||||||
|
|
||||||
# next.js
|
# next.js
|
||||||
/dist/
|
/dist/
|
||||||
|
|||||||
@@ -1,13 +1,3 @@
|
|||||||
# Debug LangBot Frontend
|
# Debug LangBot Frontend
|
||||||
|
|
||||||
Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information.
|
Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information.
|
||||||
|
|
||||||
## Tests
|
|
||||||
|
|
||||||
Run the frontend smoke tests without a backend process:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pnpm test:e2e
|
|
||||||
```
|
|
||||||
|
|
||||||
The Playwright suite starts Vite and mocks the LangBot backend and Space APIs.
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
"dev": "vite",
|
"dev": "vite",
|
||||||
"build": "tsc && vite build",
|
"build": "tsc && vite build",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"test:e2e": "playwright test",
|
|
||||||
"lint": "eslint .",
|
"lint": "eslint .",
|
||||||
"format": "prettier --write ."
|
"format": "prettier --write ."
|
||||||
},
|
},
|
||||||
@@ -87,7 +86,6 @@
|
|||||||
"zod": "^3.24.4"
|
"zod": "^3.24.4"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@playwright/test": "^1.61.0",
|
|
||||||
"@types/debug": "^4.1.12",
|
"@types/debug": "^4.1.12",
|
||||||
"@types/estree": "^1.0.8",
|
"@types/estree": "^1.0.8",
|
||||||
"@types/estree-jsx": "^1.0.5",
|
"@types/estree-jsx": "^1.0.5",
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user