From 1ae5aacc00383ba8a30ef6e92572939b169bdbd5 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Tue, 16 Jun 2026 03:09:55 +0000 Subject: [PATCH 01/16] test: add frontend smoke and backend e2e CI (#2251) --- .github/workflows/frontend-tests.yml | 46 +++ .github/workflows/lint.yml | 2 +- .github/workflows/run-tests.yml | 63 +++- tests/README.md | 62 +++- tests/e2e/conftest.py | 4 +- tests/e2e/test_startup.py | 6 +- tests/e2e/utils/process_manager.py | 17 +- web/.gitignore | 2 + web/README.md | 10 + web/package.json | 2 + web/playwright.config.ts | 25 ++ web/pnpm-lock.yaml | 35 +++ web/tests/e2e/fixtures/langbot-api.ts | 417 ++++++++++++++++++++++++++ web/tests/e2e/home-smoke.spec.ts | 133 ++++++++ web/tests/e2e/login.spec.ts | 22 ++ 15 files changed, 835 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/frontend-tests.yml create mode 100644 web/playwright.config.ts create mode 100644 web/tests/e2e/fixtures/langbot-api.ts create mode 100644 web/tests/e2e/home-smoke.spec.ts create mode 100644 web/tests/e2e/login.spec.ts diff --git a/.github/workflows/frontend-tests.yml b/.github/workflows/frontend-tests.yml new file mode 100644 index 000000000..0265e0441 --- /dev/null +++ b/.github/workflows/frontend-tests.yml @@ -0,0 +1,46 @@ +name: Frontend Tests + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + paths: + - 'web/**' + - '.github/workflows/frontend-tests.yml' + push: + branches: + - master + - develop + paths: + - 'web/**' + - '.github/workflows/frontend-tests.yml' + +jobs: + playwright-smoke: + name: Playwright Smoke + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '25' + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 8.9.2 + + - name: Install dependencies + working-directory: web + run: pnpm install --frozen-lockfile + + - name: Install Playwright browsers + working-directory: web + run: pnpm exec playwright install --with-deps chromium + + - name: Run Playwright smoke tests + working-directory: web + run: pnpm test:e2e diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e1d89c1ef..f2baae7c7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,7 +29,7 @@ jobs: run: uv sync --dev - name: Run ruff check - run: uv run ruff check src + run: uv run ruff check src/langbot/ tests/ --output-format=concise - name: Run ruff format run: uv run ruff format src --check diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index c33684012..aaee59549 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -84,6 +84,67 @@ jobs: echo "" >> $GITHUB_STEP_SUMMARY echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY + e2e: + name: E2E Startup Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Install dependencies + run: uv sync --dev + + - name: Run E2E startup tests + run: uv run pytest tests/e2e -q --tb=short + + - name: E2E Test Summary + if: always() + run: | + echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY + + box-integration: + name: Box Integration Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Install dependencies + run: uv sync --dev + + - name: Check Docker runtime + run: docker info + + - name: Run Box integration tests + run: uv run pytest tests/integration_tests -q --tb=short + + - name: Box Integration Test Summary + if: always() + run: | + echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY + coverage: name: Coverage Gate runs-on: ubuntu-latest @@ -129,4 +190,4 @@ jobs: echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY - echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY \ No newline at end of file + echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY diff --git a/tests/README.md b/tests/README.md index e490ed5cf..3a110b9dc 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,6 +1,7 @@ # LangBot Test Suite -This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages. +This directory contains the LangBot backend test suite, including unit tests, +integration tests, startup E2E tests, and container-backed Box runtime tests. ## Quality Gate Layers @@ -10,10 +11,15 @@ LangBot uses a layered quality gate system for developers and CI: |-------|---------|--------------|-------------| | **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit | | **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly | +| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI | +| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI | +| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI | | **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI | | **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes | -**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows. +**Note**: PostgreSQL migration tests and slow tests are NOT in local default +gates. They run in separate CI workflows. Frontend Playwright tests live under +`web/tests/e2e` and are documented in `web/README.md`. ### Developer Workflow @@ -28,6 +34,9 @@ make test-all-local bash scripts/test-quick.sh # ~2 min bash scripts/test-integration-fast.sh # ~3 min bash scripts/test-coverage.sh # ~8 min +uv run --python 3.12 pytest tests/e2e -q --tb=short +uv run --python 3.12 pytest tests/integration_tests -q --tb=short +cd web && pnpm test:e2e ``` ### Coverage Baseline @@ -70,6 +79,12 @@ tests/ │ └── persistence/ # Database/persistence tests │ ├── __init__.py │ └── test_migrations.py # Alembic migration tests +├── e2e/ # Real LangBot startup E2E tests +│ ├── conftest.py +│ ├── test_startup.py +│ └── utils/ +├── integration_tests/ # Container-backed integration tests +│ └── box/ # Box runtime and MCP process tests ├── smoke/ # Smoke tests (quick validation) │ └── test_fake_message_flow.py ├── unit_tests/ # Unit tests @@ -303,6 +318,44 @@ These tests: - Test prevent_default, exception handling, and full message flow - Do not require real LLM provider keys +### Running backend E2E startup tests + +Backend E2E tests start a real LangBot process with a generated minimal +`data/config.yaml`, SQLite database, local storage, and embedded Chroma path. +They do not require provider keys or external services. + +```bash +uv run --python 3.12 pytest tests/e2e -q --tb=short +``` + +These tests verify startup orchestration, migrations, API route registration, +and the minimal no-LLM startup path. The E2E process manager disables ambient +proxy variables for subprocess startup and uses direct localhost HTTP clients, +so local proxy settings should not affect the health checks. + +### Running Box integration tests + +Box integration tests exercise the real sandbox runtime path, including command +execution, session persistence, managed process WebSocket attachment, and +cleanup behavior. + +```bash +uv run --python 3.12 pytest tests/integration_tests -q --tb=short +``` + +These tests require a working Docker or Podman runtime. In CI, the dedicated +Box integration job checks Docker availability before running the tests. + +### Running frontend E2E tests + +Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite +and mock the LangBot backend and Space APIs, so no backend process is required. + +```bash +cd web +pnpm test:e2e +``` + ### Known Issues Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues: @@ -320,6 +373,9 @@ Tests are automatically run on: - Push to master/develop branches The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility. +Startup E2E and Box integration tests run as separate Python 3.12 jobs because +they exercise process/container behavior instead of pure Python compatibility. +Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`. ## Adding New Tests @@ -406,4 +462,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun - [ ] Add E2E tests - [ ] Add performance benchmarks - [ ] Add mutation testing for better coverage quality -- [ ] Add property-based testing with Hypothesis \ No newline at end of file +- [ ] Add property-based testing with Hypothesis diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 200ac22a8..ddef1abdc 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process): base_url = f'http://127.0.0.1:{e2e_port}' - with httpx.Client(base_url=base_url, timeout=10.0) as client: + with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client: yield client @pytest.fixture(scope='session') def e2e_db_path(e2e_tmpdir): """Path to SQLite database file.""" - return e2e_tmpdir / 'data' / 'langbot.db' \ No newline at end of file + return e2e_tmpdir / 'data' / 'langbot.db' diff --git a/tests/e2e/test_startup.py b/tests/e2e/test_startup.py index dcbe8e75f..8954505a2 100644 --- a/tests/e2e/test_startup.py +++ b/tests/e2e/test_startup.py @@ -38,7 +38,7 @@ class TestStartupFlow: # System info should contain version info assert 'version' in data['data'] or 'edition' in data['data'] - def test_database_initialized(self, e2e_db_path): + def test_database_initialized(self, langbot_process, e2e_db_path): """Verify SQLite database was created and initialized.""" assert e2e_db_path.exists() @@ -75,7 +75,7 @@ class TestStartupFlow: """Test auth endpoint.""" # First startup may allow initial setup response = e2e_client.post('/api/v1/user/auth', json={ - 'username': 'admin', + 'user': 'admin', 'password': 'admin', }) @@ -94,7 +94,7 @@ class TestStartupStages: # If API responds on e2e_port, config was loaded assert e2e_client.get('/api/v1/system/info').status_code == 200 - def test_migrations_applied(self, e2e_db_path): + def test_migrations_applied(self, langbot_process, e2e_db_path): """Verify database migrations were applied.""" import sqlite3 conn = sqlite3.connect(str(e2e_db_path)) diff --git a/tests/e2e/utils/process_manager.py b/tests/e2e/utils/process_manager.py index 888b5dec8..840509874 100644 --- a/tests/e2e/utils/process_manager.py +++ b/tests/e2e/utils/process_manager.py @@ -44,6 +44,17 @@ class LangBotProcess: # Prepare environment env = os.environ.copy() env['PYTHONPATH'] = str(self.project_root / 'src') + for proxy_key in ( + 'HTTP_PROXY', + 'HTTPS_PROXY', + 'ALL_PROXY', + 'http_proxy', + 'https_proxy', + 'all_proxy', + ): + env.pop(proxy_key, None) + env['NO_PROXY'] = '127.0.0.1,localhost' + env['no_proxy'] = '127.0.0.1,localhost' # Set API port via environment variable env['API__PORT'] = str(self.port) @@ -113,6 +124,8 @@ precision = 2 r = httpx.get( f'http://127.0.0.1:{self.port}/api/v1/system/info', timeout=2.0, + follow_redirects=False, + trust_env=False, ) if r.status_code == 200: logger.info(f'LangBot started successfully on port {self.port}') @@ -185,6 +198,8 @@ precision = 2 r = httpx.get( f'http://127.0.0.1:{self.port}/api/v1/system/info', timeout=5.0, + follow_redirects=False, + trust_env=False, ) return r.status_code == 200 except Exception: @@ -201,4 +216,4 @@ def find_project_root() -> Path: return parent # Fallback to LangBot-test-build directory - return Path('/home/glwuy/langbot-app/LangBot-test-build') \ No newline at end of file + return Path('/home/glwuy/langbot-app/LangBot-test-build') diff --git a/web/.gitignore b/web/.gitignore index d50a18139..1326feb7c 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -12,6 +12,8 @@ # testing /coverage +/playwright-report +/test-results # next.js /dist/ diff --git a/web/README.md b/web/README.md index ae1de3cf5..b8ed72c66 100644 --- a/web/README.md +++ b/web/README.md @@ -1,3 +1,13 @@ # Debug LangBot Frontend Please refer to the [Development Guide](https://link.langbot.app/en/docs/dev-config) for more information. + +## Tests + +Run the frontend smoke tests without a backend process: + +```bash +pnpm test:e2e +``` + +The Playwright suite starts Vite and mocks the LangBot backend and Space APIs. diff --git a/web/package.json b/web/package.json index 555f10303..59f17f40e 100644 --- a/web/package.json +++ b/web/package.json @@ -6,6 +6,7 @@ "dev": "vite", "build": "tsc && vite build", "preview": "vite preview", + "test:e2e": "playwright test", "lint": "eslint .", "format": "prettier --write ." }, @@ -86,6 +87,7 @@ "zod": "^3.24.4" }, "devDependencies": { + "@playwright/test": "^1.61.0", "@types/debug": "^4.1.12", "@types/estree": "^1.0.8", "@types/estree-jsx": "^1.0.5", diff --git a/web/playwright.config.ts b/web/playwright.config.ts new file mode 100644 index 000000000..e15c6ef9e --- /dev/null +++ b/web/playwright.config.ts @@ -0,0 +1,25 @@ +import { defineConfig, devices } from '@playwright/test'; + +export default defineConfig({ + testDir: './tests/e2e', + fullyParallel: true, + forbidOnly: !!process.env.CI, + retries: process.env.CI ? 1 : 0, + reporter: process.env.CI ? [['github'], ['list']] : 'list', + use: { + baseURL: 'http://127.0.0.1:4173', + trace: 'on-first-retry', + }, + projects: [ + { + name: 'chromium', + use: { ...devices['Desktop Chrome'] }, + }, + ], + webServer: { + command: 'pnpm exec vite --host 127.0.0.1 --port 4173', + url: 'http://127.0.0.1:4173', + reuseExistingServer: !process.env.CI, + timeout: 120_000, + }, +}); diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index c001a7654..25660ff9c 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -192,6 +192,9 @@ dependencies: version: 3.25.76 devDependencies: + '@playwright/test': + specifier: ^1.61.0 + version: 1.61.0 '@types/debug': specifier: ^4.1.12 version: 4.1.12 @@ -529,6 +532,14 @@ packages: engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0} dev: true + /@playwright/test@1.61.0: + resolution: {integrity: sha512-cKA5B6lpFEMyMGjxF54QihfYpB4FkEGH+qZhtArDEG+wezQAJY8Pq6C7T1SjWz+FFzt3TbyoXBQYk/0292TdJA==} + engines: {node: '>=18'} + hasBin: true + dependencies: + playwright: 1.61.0 + dev: true + /@radix-ui/number@1.1.1: resolution: {integrity: sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==} dev: false @@ -3204,6 +3215,14 @@ packages: engines: {node: '>=0.4.x'} dev: false + /fsevents@2.3.2: + resolution: {integrity: sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + requiresBuild: true + dev: true + optional: true + /fsevents@2.3.3: resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} @@ -4940,6 +4959,22 @@ packages: hasBin: true dev: true + /playwright-core@1.61.0: + resolution: {integrity: sha512-caX7TrY3Ml6egyDX0WUcTHDxodl/b51y5wJOdCEA36QviK/s2g081hvmGs8eaE3DWb6NYZQ6BjO/QkNRPenoPA==} + engines: {node: '>=18'} + hasBin: true + dev: true + + /playwright@1.61.0: + resolution: {integrity: sha512-Z+7BeeqQPRRzklHsVFP4KTGIyMxKUmfeRA4WisM6G3/XW6nwGeX6fX9qYaDa+CiUqpOkb2f6X3nar05R3kSuJQ==} + engines: {node: '>=18'} + hasBin: true + dependencies: + playwright-core: 1.61.0 + optionalDependencies: + fsevents: 2.3.2 + dev: true + /pngjs@5.0.0: resolution: {integrity: sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==} engines: {node: '>=10.13.0'} diff --git a/web/tests/e2e/fixtures/langbot-api.ts b/web/tests/e2e/fixtures/langbot-api.ts new file mode 100644 index 000000000..08f23a5bb --- /dev/null +++ b/web/tests/e2e/fixtures/langbot-api.ts @@ -0,0 +1,417 @@ +import { Page, Route } from '@playwright/test'; + +type JsonRecord = Record; + +interface SkillMock { + name: string; + display_name: string; + description: string; + instructions: string; + package_root: string; + updated_at: string; +} + +interface LangBotApiMockState { + skills: SkillMock[]; +} + +function ok(data: unknown) { + return { + code: 0, + message: 'ok', + data, + timestamp: Date.now(), + }; +} + +async function fulfillJson(route: Route, data: unknown) { + await route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify(ok(data)), + }); +} + +function routePath(route: Route) { + return new URL(route.request().url()).pathname; +} + +function emptyMonitoringData() { + return { + overview: { + total_messages: 0, + llm_calls: 0, + embedding_calls: 0, + model_calls: 0, + success_rate: 0, + active_sessions: 0, + }, + messages: [], + llmCalls: [], + embeddingCalls: [], + sessions: [], + errors: [], + totalCount: { + messages: 0, + llmCalls: 0, + embeddingCalls: 0, + sessions: 0, + errors: 0, + }, + }; +} + +function emptyTokenStatistics() { + return { + summary: { + total_calls: 0, + success_calls: 0, + error_calls: 0, + total_input_tokens: 0, + total_output_tokens: 0, + total_tokens: 0, + total_cost: 0, + avg_tokens_per_call: 0, + avg_duration_ms: 0, + avg_tokens_per_second: 0, + zero_token_success_calls: 0, + }, + by_model: [], + timeseries: [], + bucket: 'day', + }; +} + +function makeSkill(data: JsonRecord): SkillMock { + return { + name: String(data.name || ''), + display_name: String(data.display_name || ''), + description: String(data.description || ''), + instructions: String(data.instructions || ''), + package_root: String(data.package_root || ''), + updated_at: new Date().toISOString(), + }; +} + +async function handleBackendApi(route: Route, state: LangBotApiMockState) { + const request = route.request(); + const url = new URL(request.url()); + const path = url.pathname; + const method = request.method(); + + if (path === '/api/v1/system/info') { + return fulfillJson(route, { + debug: false, + version: 'frontend-smoke', + edition: 'community', + cloud_service_url: 'https://space.langbot.app', + enable_marketplace: true, + allow_modify_login_info: true, + disable_models_service: false, + limitation: { + max_bots: -1, + max_pipelines: -1, + max_extensions: -1, + }, + outbound_ips: [], + wizard_status: 'completed', + wizard_progress: null, + }); + } + + if (path === '/api/v1/user/account-info') { + return fulfillJson(route, { + initialized: true, + account_type: 'local', + has_password: true, + }); + } + + if (path === '/api/v1/user/check-token') { + return fulfillJson(route, { token: '' }); + } + + if (path === '/api/v1/user/auth') { + return fulfillJson(route, { token: 'playwright-token' }); + } + + if (path === '/api/v1/user/info') { + return fulfillJson(route, { + user: 'admin@example.com', + account_type: 'local', + has_password: true, + }); + } + + if (path === '/api/v1/user/space-credits') { + return fulfillJson(route, { credits: null }); + } + + if (path === '/api/v1/platform/bots') { + return fulfillJson(route, { bots: [] }); + } + + if (path === '/api/v1/pipelines') { + return fulfillJson(route, { pipelines: [] }); + } + + if (path === '/api/v1/knowledge/bases') { + return fulfillJson(route, { bases: [] }); + } + + if (path === '/api/v1/knowledge/migration/status') { + return fulfillJson(route, { + needed: false, + internal_kb_count: 0, + external_kb_count: 0, + }); + } + + if (path === '/api/v1/plugins') { + return fulfillJson(route, { plugins: [] }); + } + + if (path === '/api/v1/extensions') { + return fulfillJson(route, { extensions: [] }); + } + + if (path === '/api/v1/mcp/servers') { + return fulfillJson(route, { servers: [] }); + } + + if (path === '/api/v1/skills') { + if (method === 'POST') { + const skill = makeSkill( + JSON.parse(request.postData() || '{}') as JsonRecord, + ); + state.skills = [ + ...state.skills.filter((item) => item.name !== skill.name), + skill, + ]; + return fulfillJson(route, { skill }); + } + + return fulfillJson(route, { skills: state.skills }); + } + + const skillFileMatch = path.match( + /^\/api\/v1\/skills\/([^/]+)\/files\/(.+)$/, + ); + if (skillFileMatch) { + const skillName = decodeURIComponent(skillFileMatch[1]); + const filePath = decodeURIComponent(skillFileMatch[2]); + const skill = state.skills.find((item) => item.name === skillName); + return fulfillJson(route, { + skill: { name: skillName }, + path: filePath, + content: skill?.instructions || '', + }); + } + + const skillFilesMatch = path.match(/^\/api\/v1\/skills\/([^/]+)\/files$/); + if (skillFilesMatch) { + const skillName = decodeURIComponent(skillFilesMatch[1]); + return fulfillJson(route, { + skill: { name: skillName }, + base_path: '.', + entries: [ + { + path: 'SKILL.md', + name: 'SKILL.md', + is_dir: false, + size: null, + }, + ], + truncated: false, + }); + } + + const skillMatch = path.match(/^\/api\/v1\/skills\/([^/]+)$/); + if (skillMatch) { + const skillName = decodeURIComponent(skillMatch[1]); + const skill = state.skills.find((item) => item.name === skillName) || { + name: skillName, + display_name: '', + description: '', + instructions: '', + package_root: '', + updated_at: new Date().toISOString(), + }; + return fulfillJson(route, { skill }); + } + + if (path === '/api/v1/system/status/plugin-system') { + return fulfillJson(route, { + is_enable: true, + is_connected: true, + plugin_connector_error: '', + }); + } + + if (path === '/api/v1/plugins/debug-info') { + return fulfillJson(route, { + debug_url: 'ws://127.0.0.1:5300/plugin/debug', + plugin_debug_key: 'test-debug-key', + }); + } + + if (path === '/api/v1/box/status') { + return fulfillJson(route, { + available: true, + enabled: true, + profile: 'playwright', + recent_error_count: 0, + active_sessions: 0, + managed_processes: 0, + session_ttl_sec: 3600, + backend: { + name: 'playwright', + available: true, + }, + }); + } + + if (path === '/api/v1/box/sessions') { + return fulfillJson(route, []); + } + + if (path === '/api/v1/monitoring/data') { + return fulfillJson(route, emptyMonitoringData()); + } + + if (path === '/api/v1/monitoring/overview') { + return fulfillJson(route, emptyMonitoringData().overview); + } + + if (path === '/api/v1/monitoring/token-statistics') { + return fulfillJson(route, emptyTokenStatistics()); + } + + if (path === '/api/v1/monitoring/feedback/stats') { + return fulfillJson(route, { + total_feedback: 0, + total_likes: 0, + total_dislikes: 0, + satisfaction_rate: 0, + }); + } + + if (path === '/api/v1/monitoring/feedback') { + return fulfillJson(route, { feedback: [], total: 0 }); + } + + if (path === '/api/v1/survey/pending') { + return fulfillJson(route, { survey: null }); + } + + if (path === '/api/v1/system/tasks') { + return fulfillJson(route, { tasks: [] }); + } + + if ( + path === '/api/v1/marketplace/plugins' || + path === '/api/v1/marketplace/plugins/search' || + path === '/api/v1/marketplace/extensions/search' || + path === '/api/v1/marketplace/mcps/search' || + path === '/api/v1/marketplace/skills/search' + ) { + return fulfillJson(route, { plugins: [], total: 0 }); + } + + if (path === '/api/v1/marketplace/tags') { + return fulfillJson(route, { tags: [] }); + } + + if (path === '/api/v1/marketplace/recommendation-lists') { + return fulfillJson(route, { lists: [] }); + } + + if (path === '/api/v1/dist/info/releases') { + return fulfillJson(route, []); + } + + if (path === '/api/v1/dist/info/repo') { + return fulfillJson(route, { + repo: { + stargazers_count: 0, + forks_count: 0, + open_issues_count: 0, + }, + contributors: [], + }); + } + + await fulfillJson(route, {}); +} + +async function handleCloudApi(route: Route) { + const path = routePath(route); + + if ( + path === '/api/v1/marketplace/plugins' || + path === '/api/v1/marketplace/plugins/search' || + path === '/api/v1/marketplace/extensions/search' || + path === '/api/v1/marketplace/mcps/search' || + path === '/api/v1/marketplace/skills/search' + ) { + return fulfillJson(route, { plugins: [], total: 0 }); + } + + if (path === '/api/v1/marketplace/tags') { + return fulfillJson(route, { tags: [] }); + } + + if (path === '/api/v1/marketplace/recommendation-lists') { + return fulfillJson(route, { lists: [] }); + } + + if (path === '/api/v1/dist/info/releases') { + return fulfillJson(route, []); + } + + if (path === '/api/v1/dist/info/repo') { + return fulfillJson(route, { + repo: { + stargazers_count: 0, + forks_count: 0, + open_issues_count: 0, + }, + contributors: [], + }); + } + + await fulfillJson(route, {}); +} + +export async function installLangBotApiMocks( + page: Page, + options: { authenticated?: boolean; storage?: JsonRecord } = {}, +) { + const { authenticated = false, storage = {} } = options; + const state: LangBotApiMockState = { + skills: [], + }; + + await page.addInitScript( + ({ authenticated, storage }) => { + localStorage.setItem('langbot_language', 'en-US'); + localStorage.setItem('extensions_group_by_type', 'false'); + + if (authenticated) { + localStorage.setItem('token', 'playwright-token'); + localStorage.setItem('userEmail', 'admin@example.com'); + } else { + localStorage.removeItem('token'); + localStorage.removeItem('userEmail'); + } + + for (const [key, value] of Object.entries(storage)) { + localStorage.setItem(key, String(value)); + } + }, + { authenticated, storage }, + ); + + await page.route('**/api/v1/**', (route) => handleBackendApi(route, state)); + await page.route('https://space.langbot.app/**', handleCloudApi); +} diff --git a/web/tests/e2e/home-smoke.spec.ts b/web/tests/e2e/home-smoke.spec.ts new file mode 100644 index 000000000..32623864b --- /dev/null +++ b/web/tests/e2e/home-smoke.spec.ts @@ -0,0 +1,133 @@ +import { expect, test } from '@playwright/test'; + +import { installLangBotApiMocks } from './fixtures/langbot-api'; + +const appRoutes = [ + { + path: '/home/bots', + heading: 'Bots', + bodyText: 'Select a bot from the sidebar', + }, + { + path: '/home/pipelines', + heading: 'Pipelines', + bodyText: 'Select a pipeline from the sidebar', + }, + { + path: '/home/extensions', + heading: 'Extensions', + bodyText: 'No extensions installed', + }, + { + path: '/home/mcp', + heading: 'MCP', + bodyText: 'Select an MCP server from the sidebar', + }, + { + path: '/home/knowledge', + heading: 'Knowledge', + bodyText: 'Select a knowledge base from the sidebar', + }, +]; + +test.describe('authenticated app shell', () => { + for (const route of appRoutes) { + test(`${route.path} renders without a backend process`, async ({ + page, + }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto(route.path); + + await expect(page).toHaveURL(new RegExp(`${route.path}$`)); + await expect(page.getByText('Home').first()).toBeVisible(); + await expect( + page.getByRole('button', { name: 'Dashboard' }), + ).toBeVisible(); + await expect(page.getByText('Extensions').first()).toBeVisible(); + await expect(page.getByText(route.heading).first()).toBeVisible(); + await expect(page.getByText(route.bodyText)).toBeVisible(); + await expect(page.getByText('Backend unavailable')).toHaveCount(0); + }); + } + + test('/home/monitoring loads dashboard data from mocked APIs', async ({ + page, + }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/monitoring'); + + await expect(page).toHaveURL(/\/home\/monitoring$/); + await expect(page.getByText('Total Messages').first()).toBeVisible(); + await expect( + page.getByRole('tab', { name: 'Message Records' }), + ).toBeVisible(); + await expect( + page.getByRole('tab', { name: 'Token Monitoring' }), + ).toBeVisible(); + + await page.getByRole('tab', { name: 'Token Monitoring' }).click(); + await expect( + page.getByText('No token usage in the selected time range'), + ).toBeVisible(); + await expect(page.getByText('Unable to connect to server')).toHaveCount(0); + }); + + test('/home/extensions shows plugin debug information from the backend', async ({ + page, + }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/extensions'); + + await page.getByRole('button', { name: 'Debug Info' }).click(); + + await expect(page.getByText('Plugin Debug Information')).toBeVisible(); + await expect(page.getByRole('textbox').nth(0)).toHaveValue( + 'ws://127.0.0.1:5300/plugin/debug', + ); + await expect(page.getByRole('textbox').nth(1)).toHaveValue( + 'test-debug-key', + ); + }); + + test('/home/skills?action=create creates a manual skill', async ({ + page, + }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/skills?action=create'); + + await expect(page).toHaveURL(/\/home\/skills\?action=create$/); + await expect(page.getByText('Create Skill').first()).toBeVisible(); + await expect(page.getByText('Import Local Skill Directory')).toBeVisible(); + + const saveButton = page.getByRole('button', { name: 'Save' }); + await expect(saveButton).toBeEnabled(); + await saveButton.click(); + await expect(page.getByText('Skill name cannot be empty')).toBeVisible(); + + await page.locator('#display_name').fill('Daily Summary'); + await page.locator('#name').fill('daily_summary'); + await page + .locator('#description') + .fill('Summarizes the current conversation for handoff.'); + await page + .locator('#instructions') + .fill('Summarize the conversation in five concise bullet points.'); + await saveButton.click(); + + await expect(page).toHaveURL(/\/home\/skills\?id=daily_summary$/); + await expect( + page.getByRole('heading', { name: 'Daily Summary' }), + ).toBeVisible(); + await expect(page.locator('#name')).toHaveValue('daily_summary'); + await expect(page.locator('#description')).toHaveValue( + 'Summarizes the current conversation for handoff.', + ); + await expect(page.locator('#instructions')).toHaveValue( + 'Summarize the conversation in five concise bullet points.', + ); + }); +}); diff --git a/web/tests/e2e/login.spec.ts b/web/tests/e2e/login.spec.ts new file mode 100644 index 000000000..ae9735a65 --- /dev/null +++ b/web/tests/e2e/login.spec.ts @@ -0,0 +1,22 @@ +import { expect, test } from '@playwright/test'; + +import { installLangBotApiMocks } from './fixtures/langbot-api'; + +test('local account login reaches the authenticated home shell', async ({ + page, +}) => { + await installLangBotApiMocks(page); + + await page.goto('/login'); + + await expect(page.getByText('Welcome')).toBeVisible(); + await page.getByPlaceholder('Enter email address').fill('admin@example.com'); + await page.getByPlaceholder('Enter password').fill('password'); + await page.getByRole('button', { name: 'Login with password' }).click(); + + await expect(page).toHaveURL(/\/home$/); + await expect(page.getByText('Home').first()).toBeVisible(); + await expect(page.getByRole('button', { name: 'Dashboard' })).toBeVisible(); + await expect(page.getByText('Total Messages').first()).toBeVisible(); + await expect(page.getByText('Unable to connect to server')).toHaveCount(0); +}); From f390980d0a1f89d22808b91933d0d9576d272f25 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Tue, 16 Jun 2026 03:22:29 +0000 Subject: [PATCH 02/16] test: format test suite (#2252) --- tests/e2e/test_startup.py | 13 +- tests/e2e/utils/config_factory.py | 2 +- tests/e2e/utils/process_manager.py | 6 +- tests/factories/__init__.py | 72 +++---- tests/factories/message.py | 160 ++++++++-------- tests/factories/platform.py | 98 +++++----- tests/factories/provider.py | 80 ++++---- tests/integration/__init__.py | 2 +- tests/integration/api/__init__.py | 2 +- tests/integration/api/test_bots.py | 67 ++++--- tests/integration/api/test_embed.py | 53 +++--- tests/integration/api/test_knowledge.py | 74 ++++---- tests/integration/api/test_monitoring.py | 116 +++++------- tests/integration/api/test_providers.py | 74 ++++---- tests/integration/api/test_smoke.py | 26 +-- tests/integration/persistence/__init__.py | 2 +- .../persistence/test_migrations.py | 36 ++-- .../persistence/test_migrations_postgres.py | 34 ++-- tests/integration/pipeline/__init__.py | 2 +- tests/integration/pipeline/test_full_flow.py | 62 ++++--- tests/smoke/__init__.py | 2 +- tests/smoke/test_fake_message_flow.py | 131 ++++++------- tests/test_cwe94_debug_exec.py | 31 +--- tests/unit_tests/api/__init__.py | 2 +- tests/unit_tests/api/service/__init__.py | 2 +- .../api/service/test_apikey_service.py | 4 +- .../api/service/test_bot_service.py | 32 +--- .../api/service/test_knowledge_service.py | 40 ++-- .../api/service/test_maintenance_service.py | 16 +- .../api/service/test_mcp_service.py | 26 ++- .../api/service/test_model_service.py | 175 ++++++++++-------- .../api/service/test_pipeline_service.py | 79 ++++---- .../api/service/test_provider_service.py | 56 +++--- .../api/service/test_space_service.py | 100 +++++----- .../api/service/test_user_service.py | 4 +- .../api/service/test_webhook_service.py | 9 +- tests/unit_tests/box/test_box_service.py | 8 +- tests/unit_tests/command/__init__.py | 2 +- tests/unit_tests/command/test_cmdmgr.py | 2 +- tests/unit_tests/command/test_operator.py | 3 +- tests/unit_tests/config/test_config_loader.py | 54 +++--- tests/unit_tests/core/__init__.py | 2 +- .../core/test_app_config_validation.py | 3 +- tests/unit_tests/core/test_bootutils_deps.py | 5 + tests/unit_tests/core/test_load_config.py | 14 +- tests/unit_tests/core/test_stage.py | 2 +- tests/unit_tests/core/test_taskmgr.py | 46 ++--- tests/unit_tests/discover/test_engine.py | 82 ++++---- .../persistence/test_database_decorator.py | 4 +- .../persistence/test_mgr_methods.py | 5 +- .../persistence/test_serialize_model.py | 1 + tests/unit_tests/pipeline/test_aggregator.py | 28 +-- .../unit_tests/pipeline/test_chat_handler.py | 38 +++- tests/unit_tests/pipeline/test_cntfilter.py | 46 +++-- .../pipeline/test_command_handler.py | 13 +- tests/unit_tests/pipeline/test_longtext.py | 46 ++--- tests/unit_tests/pipeline/test_msgtrun.py | 14 +- tests/unit_tests/pipeline/test_preproc.py | 27 +-- tests/unit_tests/pipeline/test_ratelimit.py | 76 ++------ tests/unit_tests/pipeline/test_wrapper.py | 52 +++--- .../plugin/test_connector_methods.py | 50 ++--- .../plugin/test_connector_static.py | 1 + tests/unit_tests/plugin/test_extract_deps.py | 9 +- tests/unit_tests/plugin/test_handler.py | 57 +++--- .../unit_tests/plugin/test_handler_actions.py | 96 +++++----- .../unit_tests/plugin/test_handler_helpers.py | 15 +- tests/unit_tests/provider/conftest.py | 11 +- .../provider/runners/test_difysvapi_runner.py | 10 +- tests/unit_tests/provider/test_litellmchat.py | 4 +- .../unit_tests/provider/test_model_manager.py | 12 +- .../provider/test_requester_base.py | 13 +- .../provider/test_session_manager.py | 64 ++----- .../unit_tests/provider/test_tool_manager.py | 1 + tests/unit_tests/rag/test_i18n_conversion.py | 3 +- tests/unit_tests/rag/test_kbmgr.py | 51 ++--- tests/unit_tests/rag/test_runtime_service.py | 15 +- .../test_localstorage_path_traversal.py | 110 +++++------ tests/unit_tests/storage/test_s3storage.py | 5 +- .../storage/test_storage_manager.py | 32 ++-- tests/unit_tests/telemetry/test_telemetry.py | 16 +- tests/unit_tests/utils/test_funcschema.py | 1 + tests/unit_tests/utils/test_image.py | 62 +++---- tests/unit_tests/utils/test_importutil.py | 92 ++++----- tests/unit_tests/utils/test_platform.py | 3 +- tests/unit_tests/utils/test_proxy.py | 28 +-- tests/unit_tests/utils/test_runner.py | 157 +++++++--------- tests/unit_tests/vector/test_filter_utils.py | 17 +- tests/unit_tests/vector/test_mgr.py | 58 ++---- tests/unit_tests/vector/test_vdb_base.py | 60 ++++-- .../vector/test_vdb_filter_conversion.py | 50 +++-- tests/utils/__init__.py | 2 +- tests/utils/import_isolation.py | 3 +- 92 files changed, 1658 insertions(+), 1713 deletions(-) diff --git a/tests/e2e/test_startup.py b/tests/e2e/test_startup.py index 8954505a2..e63150b4c 100644 --- a/tests/e2e/test_startup.py +++ b/tests/e2e/test_startup.py @@ -44,6 +44,7 @@ class TestStartupFlow: # Database should have some tables after migration import sqlite3 + conn = sqlite3.connect(str(e2e_db_path)) cursor = conn.cursor() @@ -74,10 +75,13 @@ class TestStartupFlow: def test_auth_endpoint(self, e2e_client, e2e_tmpdir): """Test auth endpoint.""" # First startup may allow initial setup - response = e2e_client.post('/api/v1/user/auth', json={ - 'user': 'admin', - 'password': 'admin', - }) + response = e2e_client.post( + '/api/v1/user/auth', + json={ + 'user': 'admin', + 'password': 'admin', + }, + ) # Response could be: # - 200 if auth succeeds @@ -97,6 +101,7 @@ class TestStartupStages: def test_migrations_applied(self, langbot_process, e2e_db_path): """Verify database migrations were applied.""" import sqlite3 + conn = sqlite3.connect(str(e2e_db_path)) cursor = conn.cursor() diff --git a/tests/e2e/utils/config_factory.py b/tests/e2e/utils/config_factory.py index b838827cb..b2bc2f7d4 100644 --- a/tests/e2e/utils/config_factory.py +++ b/tests/e2e/utils/config_factory.py @@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]: for path in directories.values(): path.mkdir(parents=True, exist_ok=True) - return directories \ No newline at end of file + return directories diff --git a/tests/e2e/utils/process_manager.py b/tests/e2e/utils/process_manager.py index 840509874..44c6719e5 100644 --- a/tests/e2e/utils/process_manager.py +++ b/tests/e2e/utils/process_manager.py @@ -90,9 +90,11 @@ precision = 2 f.write(coveragerc_content) cmd = [ - 'coverage', 'run', + 'coverage', + 'run', '--rcfile=' + str(coveragerc_path), - '-m', 'langbot', + '-m', + 'langbot', ] else: cmd = ['uv', 'run', 'python', '-m', 'langbot'] diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 3a6e3d984..a6564c849 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -58,45 +58,45 @@ from tests.factories.platform import ( __all__ = [ # App - "FakeApp", - "fake_app", + 'FakeApp', + 'fake_app', # Message chains - "text_chain", - "group_text_chain", - "mention_chain", - "image_chain", + 'text_chain', + 'group_text_chain', + 'mention_chain', + 'image_chain', # Message events - "friend_message_event", - "group_message_event", + 'friend_message_event', + 'group_message_event', # Mock adapters - "mock_adapter", + 'mock_adapter', # Queries - "text_query", - "group_text_query", - "private_text_query", - "command_query", - "mention_query", - "empty_query", - "image_query", - "file_query", - "unsupported_query", - "voice_query", - "at_all_query", - "query_with_session", - "query_with_config", + 'text_query', + 'group_text_query', + 'private_text_query', + 'command_query', + 'mention_query', + 'empty_query', + 'image_query', + 'file_query', + 'unsupported_query', + 'voice_query', + 'at_all_query', + 'query_with_session', + 'query_with_config', # Provider - "FakeProvider", - "fake_provider", - "fake_provider_pong", - "fake_provider_timeout", - "fake_provider_auth_error", - "fake_provider_rate_limit", - "fake_provider_malformed", - "fake_model", + 'FakeProvider', + 'fake_provider', + 'fake_provider_pong', + 'fake_provider_timeout', + 'fake_provider_auth_error', + 'fake_provider_rate_limit', + 'fake_provider_malformed', + 'fake_model', # Platform - "FakePlatform", - "fake_platform", - "fake_platform_with_streaming", - "fake_platform_with_failure", - "mock_platform_adapter", -] \ No newline at end of file + 'FakePlatform', + 'fake_platform', + 'fake_platform_with_streaming', + 'fake_platform_with_failure', + 'mock_platform_adapter', +] diff --git a/tests/factories/message.py b/tests/factories/message.py index 8871c664a..9b3cc3602 100644 --- a/tests/factories/message.py +++ b/tests/factories/message.py @@ -30,32 +30,36 @@ def _next_query_id() -> int: # ============== Message Chain Factories ============== -def text_chain(text: str = "hello") -> platform_message.MessageChain: +def text_chain(text: str = 'hello') -> platform_message.MessageChain: """Create a simple text message chain.""" - return platform_message.MessageChain([ - platform_message.Plain(text=text), - ]) + return platform_message.MessageChain( + [ + platform_message.Plain(text=text), + ] + ) -def group_text_chain(text: str = "hello") -> platform_message.MessageChain: +def group_text_chain(text: str = 'hello') -> platform_message.MessageChain: """Create a group text message chain (same as text_chain, context provided by event).""" return text_chain(text) def mention_chain( - text: str = "hello", + text: str = 'hello', target: typing.Union[int, str] = 12345, ) -> platform_message.MessageChain: """Create a message chain with @mention.""" - return platform_message.MessageChain([ - platform_message.At(target=target), - platform_message.Plain(text=f" {text}"), - ]) + return platform_message.MessageChain( + [ + platform_message.At(target=target), + platform_message.Plain(text=f' {text}'), + ] + ) def image_chain( - text: str = "", - url: str = "https://example.com/image.png", + text: str = '', + url: str = 'https://example.com/image.png', ) -> platform_message.MessageChain: """Create a message chain with an image.""" components = [] @@ -66,13 +70,15 @@ def image_chain( def command_chain( - command: str = "help", - prefix: str = "/", + command: str = 'help', + prefix: str = '/', ) -> platform_message.MessageChain: """Create a command message chain.""" - return platform_message.MessageChain([ - platform_message.Plain(text=f"{prefix}{command}"), - ]) + return platform_message.MessageChain( + [ + platform_message.Plain(text=f'{prefix}{command}'), + ] + ) # ============== Message Event Factories ============== @@ -81,7 +87,7 @@ def command_chain( def friend_message_event( message_chain: platform_message.MessageChain, sender_id: typing.Union[int, str] = 12345, - nickname: str = "TestUser", + nickname: str = 'TestUser', ) -> platform_events.FriendMessage: """Create a friend (private) message event.""" sender = platform_entities.Friend( @@ -90,7 +96,7 @@ def friend_message_event( remark=None, ) return platform_events.FriendMessage( - type="FriendMessage", + type='FriendMessage', sender=sender, message_chain=message_chain, time=1609459200, @@ -100,9 +106,9 @@ def friend_message_event( def group_message_event( message_chain: platform_message.MessageChain, sender_id: typing.Union[int, str] = 12345, - sender_name: str = "TestUser", + sender_name: str = 'TestUser', group_id: typing.Union[int, str] = 99999, - group_name: str = "TestGroup", + group_name: str = 'TestGroup', ) -> platform_events.GroupMessage: """Create a group message event.""" group = platform_entities.Group( @@ -117,7 +123,7 @@ def group_message_event( group=group, ) return platform_events.GroupMessage( - type="GroupMessage", + type='GroupMessage', sender=sender, message_chain=message_chain, time=1609459200, @@ -152,36 +158,36 @@ def _base_query( query_id = _next_query_id() base_data = { - "query_id": query_id, - "launcher_type": launcher_type, - "launcher_id": launcher_id, - "sender_id": sender_id, - "message_chain": message_chain, - "message_event": message_event, - "adapter": adapter, - "pipeline_uuid": "test-pipeline-uuid", - "bot_uuid": "test-bot-uuid", - "pipeline_config": { - "ai": { - "runner": {"runner": "local-agent"}, - "local-agent": { - "model": {"primary": "test-model-uuid", "fallbacks": []}, - "prompt": "test-prompt", + 'query_id': query_id, + 'launcher_type': launcher_type, + 'launcher_id': launcher_id, + 'sender_id': sender_id, + 'message_chain': message_chain, + 'message_event': message_event, + 'adapter': adapter, + 'pipeline_uuid': 'test-pipeline-uuid', + 'bot_uuid': 'test-bot-uuid', + 'pipeline_config': { + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': { + 'model': {'primary': 'test-model-uuid', 'fallbacks': []}, + 'prompt': 'test-prompt', }, }, - "output": {"misc": {"at-sender": False, "quote-origin": False}}, - "trigger": {"misc": {"combine-quote-message": False}}, + 'output': {'misc': {'at-sender': False, 'quote-origin': False}}, + 'trigger': {'misc': {'combine-quote-message': False}}, }, - "session": None, - "prompt": None, - "messages": [], - "user_message": None, - "use_funcs": [], - "use_llm_model_uuid": None, - "variables": {}, - "resp_messages": [], - "resp_message_chain": None, - "current_stage_name": None, + 'session': None, + 'prompt': None, + 'messages': [], + 'user_message': None, + 'use_funcs': [], + 'use_llm_model_uuid': None, + 'variables': {}, + 'resp_messages': [], + 'resp_message_chain': None, + 'current_stage_name': None, } # Apply overrides @@ -192,7 +198,7 @@ def _base_query( def text_query( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -212,7 +218,7 @@ def text_query( def private_text_query( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -221,7 +227,7 @@ def private_text_query( def group_text_query( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, **overrides, @@ -242,8 +248,8 @@ def group_text_query( def command_query( - command: str = "help", - prefix: str = "/", + command: str = 'help', + prefix: str = '/', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -263,7 +269,7 @@ def command_query( def mention_query( - text: str = "hello", + text: str = 'hello', target: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, @@ -301,8 +307,8 @@ def empty_query(**overrides) -> pipeline_query.Query: def image_query( - text: str = "", - url: str = "https://example.com/image.png", + text: str = '', + url: str = 'https://example.com/image.png', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -322,9 +328,9 @@ def image_query( def file_query( - url: str = "https://example.com/document.pdf", - name: str = "document.pdf", - text: str = "", + url: str = 'https://example.com/document.pdf', + name: str = 'document.pdf', + text: str = '', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -348,8 +354,8 @@ def file_query( def unsupported_query( - unsupported_type: str = "CustomComponent", - text: str = "", + unsupported_type: str = 'CustomComponent', + text: str = '', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -358,7 +364,7 @@ def unsupported_query( if text: components.append(platform_message.Plain(text=text)) # Use Unknown component for unsupported types - components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}")) + components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}')) chain = platform_message.MessageChain(components) event = friend_message_event(chain, sender_id) adapter = mock_adapter() @@ -374,7 +380,7 @@ def unsupported_query( def query_with_session( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, session: provider_session.Session = None, **overrides, @@ -389,7 +395,7 @@ def query_with_session( launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, - use_prompt_name="default", + use_prompt_name='default', using_conversation=None, conversations=[], ) @@ -398,7 +404,7 @@ def query_with_session( def query_with_config( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, pipeline_config: dict = None, **overrides, @@ -410,22 +416,22 @@ def query_with_config( """ if pipeline_config is None: pipeline_config = { - "ai": { - "runner": {"runner": "local-agent"}, - "local-agent": { - "model": {"primary": "test-model-uuid", "fallbacks": []}, - "prompt": "test-prompt", + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': { + 'model': {'primary': 'test-model-uuid', 'fallbacks': []}, + 'prompt': 'test-prompt', }, }, - "output": {"misc": {"at-sender": False, "quote-origin": False}}, - "trigger": {"misc": {"combine-quote-message": False}}, + 'output': {'misc': {'at-sender': False, 'quote-origin': False}}, + 'trigger': {'misc': {'combine-quote-message': False}}, } return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides) def voice_query( - url: str = "https://example.com/audio.mp3", + url: str = 'https://example.com/audio.mp3', sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: @@ -448,7 +454,7 @@ def voice_query( def at_all_query( - text: str = "hello", + text: str = 'hello', sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, **overrides, @@ -456,7 +462,7 @@ def at_all_query( """Create a group query with @All mention.""" components = [ platform_message.AtAll(), - platform_message.Plain(text=f" {text}"), + platform_message.Plain(text=f' {text}'), ] chain = platform_message.MessageChain(components) event = group_message_event(chain, sender_id, group_id=group_id) @@ -469,4 +475,4 @@ def at_all_query( sender_id=sender_id, adapter=adapter, **overrides, - ) \ No newline at end of file + ) diff --git a/tests/factories/platform.py b/tests/factories/platform.py index 725cead91..77b8f11f9 100644 --- a/tests/factories/platform.py +++ b/tests/factories/platform.py @@ -33,7 +33,7 @@ class FakePlatform: def __init__( self, *, - bot_account_id: str = "test-bot", + bot_account_id: str = 'test-bot', stream_output_supported: bool = False, raise_error: Exception = None, ): @@ -48,16 +48,16 @@ class FakePlatform: # Registered listeners self._listeners: dict = {} - def raises(self, error: Exception) -> "FakePlatform": + def raises(self, error: Exception) -> 'FakePlatform': """Configure platform to raise an error on send.""" self._raise_error = error return self - def send_failure(self) -> "FakePlatform": + def send_failure(self) -> 'FakePlatform': """Configure platform to simulate send failure.""" - return self.raises(Exception("Platform send failure")) + return self.raises(Exception('Platform send failure')) - def supports_streaming(self, supported: bool = True) -> "FakePlatform": + def supports_streaming(self, supported: bool = True) -> 'FakePlatform': """Configure whether streaming output is supported.""" self._stream_output_supported = supported return self @@ -89,7 +89,7 @@ class FakePlatform: self, text: str, sender_id: typing.Union[int, str] = 12345, - nickname: str = "TestUser", + nickname: str = 'TestUser', ) -> platform_events.FriendMessage: """Create an inbound friend (private) message event.""" sender = platform_entities.Friend( @@ -97,11 +97,13 @@ class FakePlatform: nickname=nickname, remark=None, ) - chain = platform_message.MessageChain([ - platform_message.Plain(text=text), - ]) + chain = platform_message.MessageChain( + [ + platform_message.Plain(text=text), + ] + ) return platform_events.FriendMessage( - type="FriendMessage", + type='FriendMessage', sender=sender, message_chain=chain, time=1609459200, @@ -111,9 +113,9 @@ class FakePlatform: self, text: str, sender_id: typing.Union[int, str] = 12345, - sender_name: str = "TestUser", + sender_name: str = 'TestUser', group_id: typing.Union[int, str] = 99999, - group_name: str = "TestGroup", + group_name: str = 'TestGroup', mention_bot: bool = False, ) -> platform_events.GroupMessage: """Create an inbound group message event. @@ -142,12 +144,12 @@ class FakePlatform: components = [] if mention_bot: components.append(platform_message.At(target=self.bot_account_id)) - components.append(platform_message.Plain(text=" ")) + components.append(platform_message.Plain(text=' ')) components.append(platform_message.Plain(text=text)) chain = platform_message.MessageChain(components) return platform_events.GroupMessage( - type="GroupMessage", + type='GroupMessage', sender=sender, message_chain=chain, time=1609459200, @@ -155,8 +157,8 @@ class FakePlatform: def create_image_message( self, - url: str = "https://example.com/image.png", - text: str = "", + url: str = 'https://example.com/image.png', + text: str = '', sender_id: typing.Union[int, str] = 12345, is_group: bool = False, group_id: typing.Union[int, str] = 99999, @@ -169,12 +171,12 @@ class FakePlatform: chain = platform_message.MessageChain(components) if is_group: - return self.create_group_message("", sender_id, group_id=group_id) + return self.create_group_message('', sender_id, group_id=group_id) # Replace chain else: - sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None) + sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None) return platform_events.FriendMessage( - type="FriendMessage", + type='FriendMessage', sender=sender, message_chain=chain, time=1609459200, @@ -192,12 +194,14 @@ class FakePlatform: if self._raise_error: raise self._raise_error - self._outbound_messages.append({ - "type": "send", - "target_type": target_type, - "target_id": target_id, - "message": message, - }) + self._outbound_messages.append( + { + 'type': 'send', + 'target_type': target_type, + 'target_id': target_id, + 'message': message, + } + ) async def reply_message( self, @@ -209,13 +213,15 @@ class FakePlatform: if self._raise_error: raise self._raise_error - self._outbound_messages.append({ - "type": "reply", - "source_type": message_source.type, - "source": message_source, - "message": message, - "quote_origin": quote_origin, - }) + self._outbound_messages.append( + { + 'type': 'reply', + 'source_type': message_source.type, + 'source': message_source, + 'message': message, + 'quote_origin': quote_origin, + } + ) async def reply_message_chunk( self, @@ -229,15 +235,17 @@ class FakePlatform: if self._raise_error: raise self._raise_error - self._outbound_chunks.append({ - "type": "reply_chunk", - "source_type": message_source.type, - "source": message_source, - "bot_message": bot_message, - "message": message, - "quote_origin": quote_origin, - "is_final": is_final, - }) + self._outbound_chunks.append( + { + 'type': 'reply_chunk', + 'source_type': message_source.type, + 'source': message_source, + 'bot_message': bot_message, + 'message': message, + 'quote_origin': quote_origin, + 'is_final': is_final, + } + ) async def is_stream_output_supported(self) -> bool: """Return whether streaming output is supported.""" @@ -295,7 +303,7 @@ class FakePlatform: def fake_platform( - bot_account_id: str = "test-bot", + bot_account_id: str = 'test-bot', stream_output_supported: bool = False, ) -> FakePlatform: """Create a FakePlatform instance.""" @@ -328,9 +336,7 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock: adapter.reply_message = AsyncMock(side_effect=platform.reply_message) adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk) adapter.send_message = AsyncMock(side_effect=platform.send_message) - adapter.is_stream_output_supported = AsyncMock( - return_value=platform._stream_output_supported - ) + adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported) adapter._fake_platform = platform # Store for assertions - return adapter \ No newline at end of file + return adapter diff --git a/tests/factories/provider.py b/tests/factories/provider.py index d50978549..a7f9d1384 100644 --- a/tests/factories/provider.py +++ b/tests/factories/provider.py @@ -27,51 +27,51 @@ class FakeProvider: Does not require API keys. """ - PONG_RESPONSE = "LANGBOT_FAKE_PONG" + PONG_RESPONSE = 'LANGBOT_FAKE_PONG' def __init__( self, *, - default_response: str = "fake response", + default_response: str = 'fake response', streaming_chunks: list[str] = None, raise_error: Exception = None, captured_requests: list = None, ): self._default_response = default_response - self._streaming_chunks = streaming_chunks or ["fake ", "response"] + self._streaming_chunks = streaming_chunks or ['fake ', 'response'] self._raise_error = raise_error self._captured_requests = captured_requests if captured_requests is not None else [] - def returns(self, text: str) -> "FakeProvider": + def returns(self, text: str) -> 'FakeProvider': """Configure provider to return a specific text response.""" self._default_response = text self._streaming_chunks = [text] return self - def returns_streaming(self, chunks: list[str]) -> "FakeProvider": + def returns_streaming(self, chunks: list[str]) -> 'FakeProvider': """Configure provider to return streaming chunks.""" self._streaming_chunks = chunks - self._default_response = "".join(chunks) + self._default_response = ''.join(chunks) return self - def raises(self, error: Exception) -> "FakeProvider": + def raises(self, error: Exception) -> 'FakeProvider': """Configure provider to raise an error.""" self._raise_error = error return self - def timeout(self) -> "FakeProvider": + def timeout(self) -> 'FakeProvider': """Configure provider to simulate timeout.""" - return self.raises(TimeoutError("Provider timeout")) + return self.raises(TimeoutError('Provider timeout')) - def auth_error(self) -> "FakeProvider": + def auth_error(self) -> 'FakeProvider': """Configure provider to simulate auth error.""" - return self.raises(Exception("Invalid API key")) + return self.raises(Exception('Invalid API key')) - def rate_limit(self) -> "FakeProvider": + def rate_limit(self) -> 'FakeProvider': """Configure provider to simulate rate limit.""" - return self.raises(Exception("Rate limit exceeded")) + return self.raises(Exception('Rate limit exceeded')) - def malformed(self) -> "FakeProvider": + def malformed(self) -> 'FakeProvider': """Configure provider to simulate malformed response.""" self._default_response = None return self @@ -87,7 +87,7 @@ class FakeProvider: def _create_message(self, content: str) -> provider_message.Message: """Create a provider message from text content.""" return provider_message.Message( - role="assistant", + role='assistant', content=content, ) @@ -99,7 +99,7 @@ class FakeProvider: ) -> provider_message.MessageChunk: """Create a provider message chunk.""" return provider_message.MessageChunk( - role="assistant", + role='assistant', content=content, is_final=is_final, msg_sequence=msg_sequence, @@ -116,13 +116,15 @@ class FakeProvider: ) -> provider_message.Message: """Simulate non-streaming LLM invocation.""" # Capture request for assertions - self._captured_requests.append({ - "query_id": query.query_id if query else None, - "model": model.model_entity.name if model and hasattr(model, 'model_entity') else None, - "messages": messages, - "funcs": funcs, - "extra_args": extra_args, - }) + self._captured_requests.append( + { + 'query_id': query.query_id if query else None, + 'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, + 'messages': messages, + 'funcs': funcs, + 'extra_args': extra_args, + } + ) # Simulate error if configured if self._raise_error: @@ -131,7 +133,7 @@ class FakeProvider: # Return response if self._default_response is None: # Malformed response - return provider_message.Message(role="assistant", content=None) + return provider_message.Message(role='assistant', content=None) return self._create_message(self._default_response) @@ -146,14 +148,16 @@ class FakeProvider: ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: """Simulate streaming LLM invocation.""" # Capture request for assertions - self._captured_requests.append({ - "query_id": query.query_id if query else None, - "model": model.model_entity.name if model and hasattr(model, 'model_entity') else None, - "messages": messages, - "funcs": funcs, - "extra_args": extra_args, - "streaming": True, - }) + self._captured_requests.append( + { + 'query_id': query.query_id if query else None, + 'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, + 'messages': messages, + 'funcs': funcs, + 'extra_args': extra_args, + 'streaming': True, + } + ) # Simulate error if configured if self._raise_error: @@ -161,12 +165,12 @@ class FakeProvider: # Yield chunks for i, chunk in enumerate(self._streaming_chunks): - is_final = (i == len(self._streaming_chunks) - 1) + is_final = i == len(self._streaming_chunks) - 1 yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i) def fake_provider( - default_response: str = "fake response", + default_response: str = 'fake response', ) -> FakeProvider: """Create a FakeProvider with optional default response.""" return FakeProvider(default_response=default_response) @@ -202,8 +206,8 @@ def fake_provider_malformed() -> FakeProvider: def fake_model( *, - uuid: str = "test-model-uuid", - name: str = "test-model", + uuid: str = 'test-model-uuid', + name: str = 'test-model', abilities: list[str] = None, provider: FakeProvider = None, ) -> Mock: @@ -212,7 +216,7 @@ def fake_model( model.model_entity = Mock() model.model_entity.uuid = uuid model.model_entity.name = name - model.model_entity.abilities = abilities or ["func_call", "vision"] + model.model_entity.abilities = abilities or ['func_call', 'vision'] model.model_entity.extra_args = {} # Attach fake provider @@ -221,4 +225,4 @@ def fake_model( model.provider = provider - return model \ No newline at end of file + return model diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index a261bc7b8..dfc61335f 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -3,4 +3,4 @@ Integration tests package. These tests validate real system behavior with actual database/network resources. Run with: uv run pytest tests/integration/ -m "not slow" -q -""" \ No newline at end of file +""" diff --git a/tests/integration/api/__init__.py b/tests/integration/api/__init__.py index 999686642..f11571f07 100644 --- a/tests/integration/api/__init__.py +++ b/tests/integration/api/__init__.py @@ -2,4 +2,4 @@ API integration tests package. Tests for HTTP API endpoints using Quart test client. -""" \ No newline at end of file +""" diff --git a/tests/integration/api/test_bots.py b/tests/integration/api/test_bots.py index 578764ee0..0e6854bf9 100644 --- a/tests/integration/api/test_bots.py +++ b/tests/integration/api/test_bots.py @@ -48,6 +48,7 @@ def mock_circular_import_chain(): clear=clear, ): import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401 + yield @@ -56,10 +57,12 @@ def fake_bot_app(): """Create FakeApp with bot services (module scope for reuse).""" app = FakeApp() - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # Auth services app.user_service = Mock() @@ -71,28 +74,29 @@ def fake_bot_app(): # Bot service app.bot_service = Mock() - app.bot_service.get_bots = AsyncMock(return_value=[ - { + app.bot_service.get_bots = AsyncMock( + return_value=[ + { + 'uuid': 'test-bot-uuid', + 'name': 'Test Bot', + 'platform': 'telegram', + 'pipeline_uuid': 'test-pipeline-uuid', + } + ] + ) + app.bot_service.get_runtime_bot_info = AsyncMock( + return_value={ 'uuid': 'test-bot-uuid', 'name': 'Test Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline-uuid', + 'webhook_url': 'https://example.com/webhook/test-bot-uuid', } - ]) - app.bot_service.get_runtime_bot_info = AsyncMock(return_value={ - 'uuid': 'test-bot-uuid', - 'name': 'Test Bot', - 'platform': 'telegram', - 'pipeline_uuid': 'test-pipeline-uuid', - 'webhook_url': 'https://example.com/webhook/test-bot-uuid', - }) + ) app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'}) app.bot_service.update_bot = AsyncMock(return_value={}) app.bot_service.delete_bot = AsyncMock() - app.bot_service.list_event_logs = AsyncMock(return_value=( - [{'uuid': 'log-1', 'message': 'test log'}], - 1 - )) + app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1)) app.bot_service.send_message = AsyncMock() # Platform manager @@ -118,10 +122,7 @@ class TestBotEndpoints: @pytest.mark.asyncio async def test_get_bots_success(self, quart_test_client): """GET /api/v1/platform/bots returns bot list.""" - response = await quart_test_client.get( - '/api/v1/platform/bots', - headers={'Authorization': 'Bearer test_token'} - ) + response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'}) assert response.status_code == 200 data = await response.get_json() @@ -135,7 +136,7 @@ class TestBotEndpoints: response = await quart_test_client.post( '/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'} + json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}, ) assert response.status_code == 200 @@ -147,8 +148,7 @@ class TestBotEndpoints: async def test_get_single_bot_success(self, quart_test_client): """GET /api/v1/platform/bots/{uuid} returns bot with runtime info.""" response = await quart_test_client.get( - '/api/v1/platform/bots/test-bot-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -162,7 +162,7 @@ class TestBotEndpoints: response = await quart_test_client.put( '/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'Updated Bot'} + json={'name': 'Updated Bot'}, ) assert response.status_code == 200 @@ -173,8 +173,7 @@ class TestBotEndpoints: async def test_delete_bot_success(self, quart_test_client): """DELETE /api/v1/platform/bots/{uuid} deletes bot.""" response = await quart_test_client.delete( - '/api/v1/platform/bots/test-bot-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -190,7 +189,7 @@ class TestBotLogsEndpoint: response = await quart_test_client.post( '/api/v1/platform/bots/test-bot-uuid/logs', headers={'Authorization': 'Bearer test_token'}, - json={'from_index': -1, 'max_count': 10} + json={'from_index': -1, 'max_count': 10}, ) assert response.status_code == 200 @@ -213,8 +212,8 @@ class TestBotSendMessageEndpoint: json={ 'target_type': 'person', 'target_id': 'user123', - 'message_chain': [{'type': 'text', 'text': 'Hello'}] - } + 'message_chain': [{'type': 'text', 'text': 'Hello'}], + }, ) assert response.status_code == 200 @@ -228,7 +227,7 @@ class TestBotSendMessageEndpoint: response = await quart_test_client.post( '/api/v1/platform/bots/test-bot-uuid/send_message', headers={'Authorization': 'Bearer test_api_key'}, - json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]} + json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}, ) assert response.status_code == 400 @@ -244,8 +243,8 @@ class TestBotSendMessageEndpoint: json={ 'target_type': 'invalid', 'target_id': 'user123', - 'message_chain': [{'type': 'text', 'text': 'Hello'}] - } + 'message_chain': [{'type': 'text', 'text': 'Hello'}], + }, ) assert response.status_code == 400 diff --git a/tests/integration/api/test_embed.py b/tests/integration/api/test_embed.py index 12d53d42c..5d034d133 100644 --- a/tests/integration/api/test_embed.py +++ b/tests/integration/api/test_embed.py @@ -47,6 +47,7 @@ def mock_circular_import_chain(): clear=clear, ): import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401 + yield @@ -55,10 +56,12 @@ def fake_embed_app(): """Create FakeApp with embed widget services (module scope).""" app = FakeApp() - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # Create mock web_page_bot with valid UUID format mock_bot_entity = Mock() @@ -83,9 +86,7 @@ def fake_embed_app(): # WebSocket proxy bot with adapter mock_websocket_adapter = Mock() - mock_websocket_adapter.get_websocket_messages = Mock(return_value=[ - {'id': 'msg-1', 'content': 'test message'} - ]) + mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}]) mock_websocket_adapter.reset_session = Mock() mock_websocket_adapter.handle_websocket_message = AsyncMock() @@ -117,9 +118,7 @@ class TestEmbedWidgetEndpoint: @pytest.mark.asyncio async def test_get_widget_js_success(self, quart_test_client): """GET /api/v1/embed/{bot_uuid}/widget.js returns JS.""" - response = await quart_test_client.get( - '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js' - ) + response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js') assert response.status_code == 200 assert 'javascript' in response.content_type @@ -127,18 +126,14 @@ class TestEmbedWidgetEndpoint: @pytest.mark.asyncio async def test_get_widget_js_invalid_uuid(self, quart_test_client): """GET widget.js with invalid UUID returns 400.""" - response = await quart_test_client.get( - '/api/v1/embed/invalid-uuid/widget.js' - ) + response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js') assert response.status_code == 400 @pytest.mark.asyncio async def test_get_widget_js_bot_not_found(self, quart_test_client): """GET widget.js for non-existent bot returns 404.""" - response = await quart_test_client.get( - '/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js' - ) + response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js') assert response.status_code == 404 @@ -164,8 +159,7 @@ class TestEmbedTurnstileVerifyEndpoint: async def test_turnstile_verify_no_secret(self, quart_test_client): """POST turnstile verify without secret returns dummy token.""" response = await quart_test_client.post( - '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', - json={'token': 'test-token'} + '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'} ) assert response.status_code == 200 @@ -177,8 +171,7 @@ class TestEmbedTurnstileVerifyEndpoint: async def test_turnstile_verify_invalid_uuid(self, quart_test_client): """POST turnstile verify with invalid UUID returns 400.""" response = await quart_test_client.post( - '/api/v1/embed/invalid-uuid/turnstile/verify', - json={'token': 'test-token'} + '/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'} ) assert response.status_code == 400 @@ -187,8 +180,7 @@ class TestEmbedTurnstileVerifyEndpoint: async def test_turnstile_verify_missing_token(self, quart_test_client): """POST turnstile verify without token returns 400.""" response = await quart_test_client.post( - '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', - json={} + '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={} ) assert response.status_code == 400 @@ -203,7 +195,7 @@ class TestEmbedMessagesEndpoint: """GET messages/person returns messages.""" response = await quart_test_client.get( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person', - headers={'Authorization': 'Bearer 1234567890.dummy'} + headers={'Authorization': 'Bearer 1234567890.dummy'}, ) assert response.status_code == 200 @@ -216,7 +208,7 @@ class TestEmbedMessagesEndpoint: """GET messages/group returns messages.""" response = await quart_test_client.get( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group', - headers={'Authorization': 'Bearer 1234567890.dummy'} + headers={'Authorization': 'Bearer 1234567890.dummy'}, ) assert response.status_code == 200 @@ -226,7 +218,7 @@ class TestEmbedMessagesEndpoint: """GET messages with invalid session_type returns 400.""" response = await quart_test_client.get( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid', - headers={'Authorization': 'Bearer 1234567890.dummy'} + headers={'Authorization': 'Bearer 1234567890.dummy'}, ) assert response.status_code == 400 @@ -241,7 +233,7 @@ class TestEmbedResetEndpoint: """POST reset/person resets session.""" response = await quart_test_client.post( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person', - headers={'Authorization': 'Bearer 1234567890.dummy'} + headers={'Authorization': 'Bearer 1234567890.dummy'}, ) assert response.status_code == 200 @@ -252,8 +244,7 @@ class TestEmbedResetEndpoint: async def test_reset_session_invalid_uuid(self, quart_test_client): """POST reset with invalid UUID returns 400.""" response = await quart_test_client.post( - '/api/v1/embed/invalid-uuid/reset/person', - headers={'Authorization': 'Bearer 1234567890.dummy'} + '/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'} ) assert response.status_code == 400 @@ -269,7 +260,7 @@ class TestEmbedFeedbackEndpoint: response = await quart_test_client.post( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', headers={'Authorization': 'Bearer 1234567890.dummy'}, - json={'message_id': 'msg-123', 'feedback_type': 1} + json={'message_id': 'msg-123', 'feedback_type': 1}, ) assert response.status_code == 200 @@ -283,7 +274,7 @@ class TestEmbedFeedbackEndpoint: response = await quart_test_client.post( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', headers={'Authorization': 'Bearer 1234567890.dummy'}, - json={'message_id': 'msg-123', 'feedback_type': 2} + json={'message_id': 'msg-123', 'feedback_type': 2}, ) assert response.status_code == 200 @@ -294,7 +285,7 @@ class TestEmbedFeedbackEndpoint: response = await quart_test_client.post( '/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback', headers={'Authorization': 'Bearer 1234567890.dummy'}, - json={'message_id': 'msg-123', 'feedback_type': 99} + json={'message_id': 'msg-123', 'feedback_type': 99}, ) assert response.status_code == 400 diff --git a/tests/integration/api/test_knowledge.py b/tests/integration/api/test_knowledge.py index 9c6935fbb..973356c3e 100644 --- a/tests/integration/api/test_knowledge.py +++ b/tests/integration/api/test_knowledge.py @@ -49,6 +49,7 @@ def mock_circular_import_chain(): clear=clear, ): import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401 + yield @@ -57,10 +58,12 @@ def fake_knowledge_app(): """Create FakeApp with knowledge services (module scope for reuse).""" app = FakeApp() - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # Auth services app.user_service = Mock() @@ -72,33 +75,35 @@ def fake_knowledge_app(): # Knowledge service app.knowledge_service = Mock() - app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[ - { + app.knowledge_service.get_knowledge_bases = AsyncMock( + return_value=[ + { + 'uuid': 'test-kb-uuid', + 'name': 'Test Knowledge Base', + 'description': 'Test KB description', + 'engine_plugin_id': 'test/engine', + 'created_at': '2024-01-01T00:00:00', + 'updated_at': '2024-01-01T00:00:00', + } + ] + ) + app.knowledge_service.get_knowledge_base = AsyncMock( + return_value={ 'uuid': 'test-kb-uuid', 'name': 'Test Knowledge Base', 'description': 'Test KB description', 'engine_plugin_id': 'test/engine', - 'created_at': '2024-01-01T00:00:00', - 'updated_at': '2024-01-01T00:00:00', } - ]) - app.knowledge_service.get_knowledge_base = AsyncMock(return_value={ - 'uuid': 'test-kb-uuid', - 'name': 'Test Knowledge Base', - 'description': 'Test KB description', - 'engine_plugin_id': 'test/engine', - }) + ) app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'}) app.knowledge_service.update_knowledge_base = AsyncMock(return_value={}) app.knowledge_service.delete_knowledge_base = AsyncMock() - app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[ - {'uuid': 'test-file-uuid', 'filename': 'test.pdf'} - ]) + app.knowledge_service.get_files_by_knowledge_base = AsyncMock( + return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}] + ) app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'}) app.knowledge_service.delete_file = AsyncMock() - app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[ - {'content': 'test result', 'score': 0.95} - ]) + app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}]) # RAG manager app.rag_mgr = Mock() @@ -124,8 +129,7 @@ class TestKnowledgeBaseEndpoints: async def test_get_knowledge_bases_success(self, quart_test_client): """GET /api/v1/knowledge/bases returns knowledge base list.""" response = await quart_test_client.get( - '/api/v1/knowledge/bases', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -140,7 +144,7 @@ class TestKnowledgeBaseEndpoints: response = await quart_test_client.post( '/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New KB', 'engine_plugin_id': 'test/engine'} + json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}, ) assert response.status_code == 200 @@ -152,8 +156,7 @@ class TestKnowledgeBaseEndpoints: async def test_get_single_knowledge_base_success(self, quart_test_client): """GET /api/v1/knowledge/bases/{uuid} returns knowledge base.""" response = await quart_test_client.get( - '/api/v1/knowledge/bases/test-kb-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -167,7 +170,7 @@ class TestKnowledgeBaseEndpoints: response = await quart_test_client.put( '/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'Updated KB'} + json={'name': 'Updated KB'}, ) assert response.status_code == 200 @@ -178,8 +181,7 @@ class TestKnowledgeBaseEndpoints: async def test_delete_knowledge_base_success(self, quart_test_client): """DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base.""" response = await quart_test_client.delete( - '/api/v1/knowledge/bases/test-kb-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -193,8 +195,7 @@ class TestKnowledgeBaseFilesEndpoints: async def test_get_files_success(self, quart_test_client): """GET /api/v1/knowledge/bases/{uuid}/files returns files.""" response = await quart_test_client.get( - '/api/v1/knowledge/bases/test-kb-uuid/files', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -208,7 +209,7 @@ class TestKnowledgeBaseFilesEndpoints: response = await quart_test_client.post( '/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}, - json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'} + json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}, ) assert response.status_code == 200 @@ -220,8 +221,7 @@ class TestKnowledgeBaseFilesEndpoints: async def test_delete_file_from_knowledge_base(self, quart_test_client): """DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}.""" response = await quart_test_client.delete( - '/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint: response = await quart_test_client.post( '/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, - json={'query': 'test query', 'retrieval_settings': {'top_k': 5}} + json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}, ) assert response.status_code == 200 @@ -249,9 +249,7 @@ class TestKnowledgeBaseRetrieveEndpoint: async def test_retrieve_without_query_returns_error(self, quart_test_client): """POST retrieve without query returns 400.""" response = await quart_test_client.post( - '/api/v1/knowledge/bases/test-kb-uuid/retrieve', - headers={'Authorization': 'Bearer test_token'}, - json={} + '/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={} ) assert response.status_code == 400 diff --git a/tests/integration/api/test_monitoring.py b/tests/integration/api/test_monitoring.py index 8291bcd13..6a65790ff 100644 --- a/tests/integration/api/test_monitoring.py +++ b/tests/integration/api/test_monitoring.py @@ -46,6 +46,7 @@ def mock_circular_import_chain(): clear=clear, ): import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401 + yield @@ -54,10 +55,12 @@ def fake_monitoring_app(): """Create FakeApp with monitoring services (module scope).""" app = FakeApp() - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email app.user_service = Mock() @@ -67,40 +70,34 @@ def fake_monitoring_app(): # Monitoring service app.monitoring_service = Mock() - app.monitoring_service.get_overview_metrics = AsyncMock(return_value={ - 'total_messages': 100, - 'total_llm_calls': 50, - 'total_sessions': 20, - 'active_sessions': 5, - 'total_errors': 2, - }) - app.monitoring_service.get_messages = AsyncMock(return_value=( - [{'id': 'msg-1', 'content': 'test'}], 100 - )) - app.monitoring_service.get_llm_calls = AsyncMock(return_value=( - [{'id': 'llm-1'}], 50 - )) - app.monitoring_service.get_embedding_calls = AsyncMock(return_value=( - [{'id': 'emb-1'}], 10 - )) - app.monitoring_service.get_sessions = AsyncMock(return_value=( - [{'session_id': 'sess-1'}], 20 - )) - app.monitoring_service.get_errors = AsyncMock(return_value=( - [{'id': 'err-1'}], 2 - )) - app.monitoring_service.get_session_analysis = AsyncMock(return_value={ - 'found': True, - 'session_id': 'sess-1', - }) - app.monitoring_service.get_message_details = AsyncMock(return_value={ - 'found': True, - 'message_id': 'msg-1', - }) + app.monitoring_service.get_overview_metrics = AsyncMock( + return_value={ + 'total_messages': 100, + 'total_llm_calls': 50, + 'total_sessions': 20, + 'active_sessions': 5, + 'total_errors': 2, + } + ) + app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100)) + app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50)) + app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10)) + app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20)) + app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2)) + app.monitoring_service.get_session_analysis = AsyncMock( + return_value={ + 'found': True, + 'session_id': 'sess-1', + } + ) + app.monitoring_service.get_message_details = AsyncMock( + return_value={ + 'found': True, + 'message_id': 'msg-1', + } + ) app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10}) - app.monitoring_service.get_feedback_list = AsyncMock(return_value=( - [{'feedback_id': 'fb-1'}], 12 - )) + app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12)) app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}]) app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}]) app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}]) @@ -130,8 +127,7 @@ class TestMonitoringOverviewEndpoint: async def test_get_overview_success(self, quart_test_client): """GET /api/v1/monitoring/overview returns metrics.""" response = await quart_test_client.get( - '/api/v1/monitoring/overview', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -147,8 +143,7 @@ class TestMonitoringMessagesEndpoint: async def test_get_messages_success(self, quart_test_client): """GET /api/v1/monitoring/messages returns message list.""" response = await quart_test_client.get( - '/api/v1/monitoring/messages', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -165,8 +160,7 @@ class TestMonitoringLLMCallsEndpoint: async def test_get_llm_calls_success(self, quart_test_client): """GET /api/v1/monitoring/llm-calls.""" response = await quart_test_client.get( - '/api/v1/monitoring/llm-calls', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -180,8 +174,7 @@ class TestMonitoringEmbeddingCallsEndpoint: async def test_get_embedding_calls_success(self, quart_test_client): """GET /api/v1/monitoring/embedding-calls.""" response = await quart_test_client.get( - '/api/v1/monitoring/embedding-calls', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -195,8 +188,7 @@ class TestMonitoringSessionsEndpoint: async def test_get_sessions_success(self, quart_test_client): """GET /api/v1/monitoring/sessions.""" response = await quart_test_client.get( - '/api/v1/monitoring/sessions', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -210,8 +202,7 @@ class TestMonitoringErrorsEndpoint: async def test_get_errors_success(self, quart_test_client): """GET /api/v1/monitoring/errors.""" response = await quart_test_client.get( - '/api/v1/monitoring/errors', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -225,8 +216,7 @@ class TestMonitoringAllDataEndpoint: async def test_get_all_data_success(self, quart_test_client): """GET /api/v1/monitoring/data returns all data.""" response = await quart_test_client.get( - '/api/v1/monitoring/data', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -242,8 +232,7 @@ class TestMonitoringDetailsEndpoints: async def test_get_session_analysis(self, quart_test_client): """GET /api/v1/monitoring/sessions/{id}/analysis.""" response = await quart_test_client.get( - '/api/v1/monitoring/sessions/sess-1/analysis', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -252,8 +241,7 @@ class TestMonitoringDetailsEndpoints: async def test_get_message_details(self, quart_test_client): """GET /api/v1/monitoring/messages/{id}/details.""" response = await quart_test_client.get( - '/api/v1/monitoring/messages/msg-1/details', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -267,8 +255,7 @@ class TestMonitoringFeedbackEndpoints: async def test_get_feedback_stats(self, quart_test_client): """GET /api/v1/monitoring/feedback/stats.""" response = await quart_test_client.get( - '/api/v1/monitoring/feedback/stats', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -277,8 +264,7 @@ class TestMonitoringFeedbackEndpoints: async def test_get_feedback_list(self, quart_test_client): """GET /api/v1/monitoring/feedback.""" response = await quart_test_client.get( - '/api/v1/monitoring/feedback', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -292,8 +278,7 @@ class TestMonitoringExportEndpoint: async def test_export_messages(self, quart_test_client): """GET export?type=messages returns CSV.""" response = await quart_test_client.get( - '/api/v1/monitoring/export?type=messages', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -303,8 +288,7 @@ class TestMonitoringExportEndpoint: async def test_export_llm_calls(self, quart_test_client): """GET export?type=llm-calls returns CSV.""" response = await quart_test_client.get( - '/api/v1/monitoring/export?type=llm-calls', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -313,8 +297,7 @@ class TestMonitoringExportEndpoint: async def test_export_sessions(self, quart_test_client): """GET export?type=sessions returns CSV.""" response = await quart_test_client.get( - '/api/v1/monitoring/export?type=sessions', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -323,8 +306,7 @@ class TestMonitoringExportEndpoint: async def test_export_feedback(self, quart_test_client): """GET export?type=feedback returns CSV.""" response = await quart_test_client.get( - '/api/v1/monitoring/export?type=feedback', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 diff --git a/tests/integration/api/test_providers.py b/tests/integration/api/test_providers.py index 4dfa862e8..a42a99428 100644 --- a/tests/integration/api/test_providers.py +++ b/tests/integration/api/test_providers.py @@ -49,6 +49,7 @@ def mock_circular_import_chain(): ): import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401 import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401 + yield @@ -57,10 +58,12 @@ def fake_provider_app(): """Create FakeApp with provider/model services (module scope for reuse).""" app = FakeApp() - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # Auth services app.user_service = Mock() @@ -72,27 +75,23 @@ def fake_provider_app(): # Provider service app.provider_service = Mock() - app.provider_service.get_providers = AsyncMock(return_value=[ - {'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'} - ]) - app.provider_service.get_provider = AsyncMock(return_value={ - 'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl' - }) + app.provider_service.get_providers = AsyncMock( + return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}] + ) + app.provider_service.get_provider = AsyncMock( + return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'} + ) app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid') app.provider_service.update_provider = AsyncMock(return_value={}) app.provider_service.delete_provider = AsyncMock() - app.provider_service.get_provider_model_counts = AsyncMock(return_value={ - 'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0 - }) + app.provider_service.get_provider_model_counts = AsyncMock( + return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0} + ) # LLM model service app.llm_model_service = Mock() - app.llm_model_service.get_llm_models = AsyncMock(return_value=[ - {'uuid': 'test-model-uuid', 'name': 'gpt-4'} - ]) - app.llm_model_service.get_llm_model = AsyncMock(return_value={ - 'uuid': 'test-model-uuid', 'name': 'gpt-4' - }) + app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}]) + app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'}) app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'}) app.llm_model_service.update_llm_model = AsyncMock(return_value={}) app.llm_model_service.delete_llm_model = AsyncMock() @@ -133,8 +132,7 @@ class TestProviderEndpoints: async def test_get_providers_success(self, quart_test_client): """GET /api/v1/provider/providers returns provider list with complete structure.""" response = await quart_test_client.get( - '/api/v1/provider/providers', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -157,8 +155,7 @@ class TestProviderEndpoints: async def test_get_single_provider_success(self, quart_test_client): """GET /api/v1/provider/providers/{uuid} returns complete provider structure.""" response = await quart_test_client.get( - '/api/v1/provider/providers/test-provider-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -177,7 +174,7 @@ class TestProviderEndpoints: response = await quart_test_client.post( '/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New Provider', 'requester': 'chatcmpl'} + json={'name': 'New Provider', 'requester': 'chatcmpl'}, ) assert response.status_code == 200 @@ -194,7 +191,7 @@ class TestProviderEndpoints: response = await quart_test_client.put( '/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'Updated Provider'} + json={'name': 'Updated Provider'}, ) assert response.status_code == 200 @@ -205,8 +202,7 @@ class TestProviderEndpoints: async def test_delete_provider_success(self, quart_test_client): """DELETE /api/v1/provider/providers/{uuid} deletes provider.""" response = await quart_test_client.delete( - '/api/v1/provider/providers/test-provider-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -215,8 +211,7 @@ class TestProviderEndpoints: async def test_get_provider_includes_model_counts(self, quart_test_client): """GET provider response includes model counts.""" response = await quart_test_client.get( - '/api/v1/provider/providers/test-provider-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -237,8 +232,7 @@ class TestModelEndpoints: async def test_get_llm_models_success(self, quart_test_client): """GET /api/v1/provider/models/llm returns model list.""" response = await quart_test_client.get( - '/api/v1/provider/models/llm', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -250,8 +244,7 @@ class TestModelEndpoints: async def test_get_single_llm_model_success(self, quart_test_client): """GET /api/v1/provider/models/llm/{uuid} returns model.""" response = await quart_test_client.get( - '/api/v1/provider/models/llm/test-model-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -264,7 +257,7 @@ class TestModelEndpoints: response = await quart_test_client.post( '/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'} + json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}, ) assert response.status_code == 200 @@ -276,8 +269,7 @@ class TestModelEndpoints: async def test_delete_llm_model_success(self, quart_test_client): """DELETE /api/v1/provider/models/llm/{uuid} deletes model.""" response = await quart_test_client.delete( - '/api/v1/provider/models/llm/test-model-uuid', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -291,8 +283,7 @@ class TestEmbeddingModelEndpoints: async def test_get_embedding_models_success(self, quart_test_client): """GET /api/v1/provider/models/embedding returns model list.""" response = await quart_test_client.get( - '/api/v1/provider/models/embedding', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -306,7 +297,7 @@ class TestEmbeddingModelEndpoints: response = await quart_test_client.post( '/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'} + json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}, ) assert response.status_code == 200 @@ -323,8 +314,7 @@ class TestRerankModelEndpoints: async def test_get_rerank_models_success(self, quart_test_client): """GET /api/v1/provider/models/rerank returns model list.""" response = await quart_test_client.get( - '/api/v1/provider/models/rerank', - headers={'Authorization': 'Bearer test_token'} + '/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'} ) assert response.status_code == 200 @@ -338,7 +328,7 @@ class TestRerankModelEndpoints: response = await quart_test_client.post( '/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}, - json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'} + json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}, ) assert response.status_code == 200 diff --git a/tests/integration/api/test_smoke.py b/tests/integration/api/test_smoke.py index 460db55bd..9f611bb6c 100644 --- a/tests/integration/api/test_smoke.py +++ b/tests/integration/api/test_smoke.py @@ -20,6 +20,7 @@ pytestmark = pytest.mark.integration # ============== FIXTURE FOR SYS.MODULES ISOLATION ============== + @pytest.fixture(scope='module') def mock_circular_import_chain(): """ @@ -69,6 +70,7 @@ def mock_circular_import_chain(): # ============== FAKE APPLICATION FOR API TESTS ============== + @pytest.fixture def fake_api_app(): """ @@ -79,12 +81,14 @@ def fake_api_app(): app = FakeApp() # API-specific config - app.instance_config.data.update({ - 'api': {'port': 5300}, - 'plugin': {'enable_marketplace': True}, - 'space': {'url': 'https://space.langbot.app'}, - 'system': {'allow_modify_login_info': True, 'limitation': {}}, - }) + app.instance_config.data.update( + { + 'api': {'port': 5300}, + 'plugin': {'enable_marketplace': True}, + 'space': {'url': 'https://space.langbot.app'}, + 'system': {'allow_modify_login_info': True, 'limitation': {}}, + } + ) # API-specific services app.user_service = Mock() @@ -118,6 +122,7 @@ def fake_api_app(): # ============== QUART TEST CLIENT FIXTURE ============== + @pytest.fixture async def quart_test_client(fake_api_app, http_controller_cls): """ @@ -135,6 +140,7 @@ async def quart_test_client(fake_api_app, http_controller_cls): # ============== API SMOKE TESTS ============== + @pytest.mark.usefixtures('mock_circular_import_chain') class TestHealthEndpoint: """Tests for /healthz endpoint - simplest smoke test.""" @@ -222,8 +228,7 @@ class TestProtectedEndpoints: Protected endpoint returns 401 with invalid token. """ response = await quart_test_client.get( - '/api/v1/user/check-token', - headers={'Authorization': 'Bearer invalid_token'} + '/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'} ) assert response.status_code == 401 @@ -254,10 +259,7 @@ class TestInvalidPayload: """ POST with wrong JSON structure returns stable error. """ - response = await quart_test_client.post( - '/api/v1/user/auth', - json={'wrong_field': 'value'} - ) + response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'}) # Should return error with stable JSON structure assert response.status_code in (400, 500, 401) diff --git a/tests/integration/persistence/__init__.py b/tests/integration/persistence/__init__.py index 496ef8684..0cab56344 100644 --- a/tests/integration/persistence/__init__.py +++ b/tests/integration/persistence/__init__.py @@ -2,4 +2,4 @@ Persistence integration tests package. Tests for database migrations and storage behavior. -""" \ No newline at end of file +""" diff --git a/tests/integration/persistence/test_migrations.py b/tests/integration/persistence/test_migrations.py index be3427a53..f9872f829 100644 --- a/tests/integration/persistence/test_migrations.py +++ b/tests/integration/persistence/test_migrations.py @@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration @pytest.fixture def sqlite_db_url(tmp_path): """Create SQLite URL with temporary database file.""" - db_file = tmp_path / "test_migrations.db" - return f"sqlite+aiosqlite:///{db_file}" + db_file = tmp_path / 'test_migrations.db' + return f'sqlite+aiosqlite:///{db_file}' @pytest.fixture @@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade: # Verify revision rev = await get_alembic_current(sqlite_engine) - assert rev is not None, "Expected a revision after upgrade" + assert rev is not None, 'Expected a revision after upgrade' # Head should be the latest migration - assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}" + assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}' @pytest.mark.asyncio async def test_upgrade_idempotent(self, sqlite_engine): @@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade: await run_alembic_upgrade(sqlite_engine, 'head') rev2 = await get_alembic_current(sqlite_engine) - assert rev2 == rev1, f"Expected {rev1}, got {rev2}" + assert rev2 == rev1, f'Expected {rev1}, got {rev2}' class TestSQLiteMigrationFreshDatabase: @@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase: 4. Verify revision """ # Use different DB file for fresh test - fresh_db_file = tmp_path / "test_migrations_fresh.db" - fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}" + fresh_db_file = tmp_path / 'test_migrations_fresh.db' + fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}' fresh_engine = create_async_engine(fresh_url) # Create tables on fresh DB @@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase: # Verify revision rev = await get_alembic_current(fresh_engine) - assert rev is not None, "Expected a revision on fresh DB" + assert rev is not None, 'Expected a revision on fresh DB' await fresh_engine.dispose() @@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase: IMPORTANT: This test verifies the ACTUAL behavior, not accepting any arbitrary failure with try-except pass. """ - fresh_db_file = tmp_path / "test_empty_migrations.db" - fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}" + fresh_db_file = tmp_path / 'test_empty_migrations.db' + fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}' fresh_engine = create_async_engine(fresh_url) # Capture the actual behavior @@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase: # Verify specific behavior - one of two outcomes is expected if actual_result is not None: # Migration succeeded - verify revision exists - assert actual_result is not None, "Revision should exist after successful migration" + assert actual_result is not None, 'Revision should exist after successful migration' else: # Migration failed - verify the error type is known # Alembic typically raises specific errors for missing tables - assert actual_error is not None, "Error should be captured if migration failed" + assert actual_error is not None, 'Error should be captured if migration failed' # Log the error type for documentation (don't silently pass) error_type = type(actual_error).__name__ # Acceptable error types for empty DB scenarios acceptable_errors = [ 'OperationalError', # SQLite table not found 'ProgrammingError', # SQLAlchemy errors - 'CommandError', # Alembic command errors + 'CommandError', # Alembic command errors ] assert error_type in acceptable_errors, ( - f"Unexpected error type: {error_type}. " - f"This may indicate a regression in migration behavior. " - f"Error: {actual_error}" + f'Unexpected error type: {error_type}. ' + f'This may indicate a regression in migration behavior. ' + f'Error: {actual_error}' ) @@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent: # No stamp - should return None rev = await get_alembic_current(sqlite_engine) - assert rev is None, f"Expected None for unstamped DB, got {rev}" + assert rev is None, f'Expected None for unstamped DB, got {rev}' @pytest.mark.asyncio async def test_get_current_after_stamp_returns_revision(self, sqlite_engine): @@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent: await run_alembic_stamp(sqlite_engine, '0001_baseline') rev = await get_alembic_current(sqlite_engine) - assert rev == '0001_baseline' \ No newline at end of file + assert rev == '0001_baseline' diff --git a/tests/integration/persistence/test_migrations_postgres.py b/tests/integration/persistence/test_migrations_postgres.py index 20f892154..28d06a2c8 100644 --- a/tests/integration/persistence/test_migrations_postgres.py +++ b/tests/integration/persistence/test_migrations_postgres.py @@ -34,14 +34,14 @@ def postgres_url(): """Get PostgreSQL URL from environment.""" url = os.environ.get('TEST_POSTGRES_URL') if not url: - pytest.skip("TEST_POSTGRES_URL not set") + pytest.skip('TEST_POSTGRES_URL not set') return url @pytest.fixture async def postgres_engine(postgres_url): """Create async PostgreSQL engine.""" - engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT") + engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT') yield engine await engine.dispose() @@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine): async with postgres_engine.begin() as conn: # Drop alembic_version table if exists try: - await conn.execute(text("DROP TABLE IF EXISTS alembic_version")) + await conn.execute(text('DROP TABLE IF EXISTS alembic_version')) except Exception: pass @@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine): async with postgres_engine.begin() as conn: try: - await conn.execute(text("DROP TABLE IF EXISTS alembic_version")) + await conn.execute(text('DROP TABLE IF EXISTS alembic_version')) except Exception: pass @@ -83,9 +83,7 @@ class TestPostgreSQLMigrationBaseline: """Tests for baseline stamp workflow on PostgreSQL.""" @pytest.mark.asyncio - async def test_postgres_baseline_stamp_sets_revision( - self, postgres_engine, clean_tables, clean_alembic_version - ): + async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version): """ Stamp baseline on existing tables sets correct revision. @@ -106,9 +104,7 @@ class TestPostgreSQLMigrationBaseline: assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}" @pytest.mark.asyncio - async def test_postgres_baseline_stamp_on_empty_db( - self, postgres_engine, clean_tables, clean_alembic_version - ): + async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version): """ Stamp on empty database (no tables) still sets revision. @@ -125,9 +121,7 @@ class TestPostgreSQLMigrationUpgrade: """Tests for upgrade to head workflow on PostgreSQL.""" @pytest.mark.asyncio - async def test_postgres_upgrade_from_baseline_to_head( - self, postgres_engine, clean_tables, clean_alembic_version - ): + async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version): """ Upgrade from baseline to head applies all migrations. @@ -149,14 +143,12 @@ class TestPostgreSQLMigrationUpgrade: # Verify revision rev = await get_alembic_current(postgres_engine) - assert rev is not None, "Expected a revision after upgrade" + assert rev is not None, 'Expected a revision after upgrade' # Head should be the latest migration (0005 for current state) - assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}" + assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}' @pytest.mark.asyncio - async def test_postgres_upgrade_idempotent( - self, postgres_engine, clean_tables, clean_alembic_version - ): + async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version): """ Running upgrade to head multiple times is idempotent. @@ -180,7 +172,7 @@ class TestPostgreSQLMigrationUpgrade: await run_alembic_upgrade(postgres_engine, 'head') rev2 = await get_alembic_current(postgres_engine) - assert rev2 == rev1, f"Expected {rev1}, got {rev2}" + assert rev2 == rev1, f'Expected {rev1}, got {rev2}' class TestPostgreSQLMigrationGetCurrent: @@ -199,7 +191,7 @@ class TestPostgreSQLMigrationGetCurrent: # No stamp - should return None rev = await get_alembic_current(postgres_engine) - assert rev is None, f"Expected None for unstamped DB, got {rev}" + assert rev is None, f'Expected None for unstamped DB, got {rev}' @pytest.mark.asyncio async def test_postgres_get_current_after_stamp_returns_revision( @@ -214,4 +206,4 @@ class TestPostgreSQLMigrationGetCurrent: await run_alembic_stamp(postgres_engine, '0001_baseline') rev = await get_alembic_current(postgres_engine) - assert rev == '0001_baseline' \ No newline at end of file + assert rev == '0001_baseline' diff --git a/tests/integration/pipeline/__init__.py b/tests/integration/pipeline/__init__.py index 9351eaba7..7cb13296e 100644 --- a/tests/integration/pipeline/__init__.py +++ b/tests/integration/pipeline/__init__.py @@ -2,4 +2,4 @@ Pipeline integration tests package. Tests for full pipeline flow using fake provider/runner. -""" \ No newline at end of file +""" diff --git a/tests/integration/pipeline/test_full_flow.py b/tests/integration/pipeline/test_full_flow.py index 08acce4cc..6aa704436 100644 --- a/tests/integration/pipeline/test_full_flow.py +++ b/tests/integration/pipeline/test_full_flow.py @@ -26,6 +26,7 @@ pytestmark = pytest.mark.integration # ============== FIXTURE FOR SYS.MODULES ISOLATION ============== + @pytest.fixture(scope='module') def mock_circular_import_chain(): """ @@ -103,6 +104,7 @@ def mock_circular_import_chain(): # ============== FAKE RUNNER ============== + class FakeRunner: """Minimal fake runner class for pipeline integration tests. @@ -117,12 +119,13 @@ class FakeRunner: self.config = config or {} self._provider = FakeProvider() # Instance-level configuration set via class attribute - self._response_text = "fake response" + self._response_text = 'fake response' self._raise_error = None @classmethod def returns(cls, text: str): """Create a runner class configured to return specific text.""" + # We create a subclass with configured response class ConfiguredRunner(cls): name = cls.name @@ -132,11 +135,13 @@ class FakeRunner: def __init__(self, app=None, config=None): super().__init__(app, config) self._response_text = text + return ConfiguredRunner @classmethod def raises(cls, error: Exception): """Create a runner class configured to raise an error.""" + class ConfiguredRunner(cls): name = cls.name _response_text = None @@ -145,6 +150,7 @@ class FakeRunner: def __init__(self, app=None, config=None): super().__init__(app, config) self._raise_error = error + return ConfiguredRunner async def run(self, query): @@ -161,6 +167,7 @@ class FakeRunner: # ============== PIPELINE APP FIXTURE ============== + @pytest.fixture def pipeline_app(): """ @@ -187,6 +194,7 @@ def pipeline_app(): def __init__(self, name, messages): self.name = name self.messages = messages + def copy(self): return MockPrompt(self.name, list(self.messages)) @@ -237,14 +245,17 @@ def fake_platform_adapter(): @pytest.fixture def set_fake_runner(): """Factory fixture to set a fake runner CLASS in preregistered_runners.""" + def _set_runner(runner_cls): # preregistered_runners expects a list of runner classes sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls] + return _set_runner # ============== PIPELINE CONFIGURATION ============== + def create_minimal_pipeline_config(): """Create minimal pipeline configuration for tests.""" return { @@ -273,6 +284,7 @@ def create_minimal_pipeline_config(): # ============== HELPER TO PROCESS COROUTINE/GENERATOR ============== + async def collect_processor_results(processor, query, stage_name): """ Helper to handle the coroutine -> async_generator pattern. @@ -296,6 +308,7 @@ async def collect_processor_results(processor, query, stage_name): # ============== TESTS ============== + @pytest.mark.usefixtures('mock_circular_import_chain') class TestPipelineStageChainReal: """Tests for real pipeline stage chain.""" @@ -337,7 +350,7 @@ class TestPreProcessorStage: adapter, platform = fake_platform_adapter # Create query with adapter - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() @@ -365,7 +378,7 @@ class TestPreProcessorStage: adapter, platform = fake_platform_adapter - query = text_query("test message content") + query = text_query('test message content') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() @@ -396,11 +409,11 @@ class TestProcessorStage: adapter, platform = fake_platform_adapter # Set fake runner that returns pong - fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG") + fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG') set_fake_runner(fake_runner) # Create query - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() query.resp_messages = [] @@ -414,6 +427,7 @@ class TestProcessorStage: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -432,7 +446,7 @@ class TestProcessorStage: adapter, platform = fake_platform_adapter # Create query - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() @@ -445,6 +459,7 @@ class TestProcessorStage: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -462,13 +477,13 @@ class TestProcessorStage: adapter, platform = fake_platform_adapter # Create query - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() query.resp_messages = [] # Create reply chain - reply_chain = text_chain("plugin response") + reply_chain = text_chain('plugin response') # Mock plugin_connector to prevent default with reply mock_event_ctx = Mock() @@ -479,6 +494,7 @@ class TestProcessorStage: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -502,7 +518,7 @@ class TestRunnerExceptionFlow: adapter, platform = fake_platform_adapter # Set fake runner that raises exception - fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded")) + fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded')) set_fake_runner(fake_runner) # Create query with exception handling config @@ -510,7 +526,7 @@ class TestRunnerExceptionFlow: config['output']['misc']['exception-handling'] = 'show-hint' config['output']['misc']['failure-hint'] = 'Request failed.' - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = config @@ -523,6 +539,7 @@ class TestRunnerExceptionFlow: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -541,14 +558,14 @@ class TestRunnerExceptionFlow: adapter, platform = fake_platform_adapter # Set fake runner that raises specific exception - fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error")) + fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error')) set_fake_runner(fake_runner) # Create query with show-error mode config = create_minimal_pipeline_config() config['output']['misc']['exception-handling'] = 'show-error' - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = config @@ -561,6 +578,7 @@ class TestRunnerExceptionFlow: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -578,14 +596,14 @@ class TestRunnerExceptionFlow: adapter, platform = fake_platform_adapter # Set fake runner that raises exception - fake_runner = FakeRunner().raises(Exception("Hidden error")) + fake_runner = FakeRunner().raises(Exception('Hidden error')) set_fake_runner(fake_runner) # Create query with hide mode config = create_minimal_pipeline_config() config['output']['misc']['exception-handling'] = 'hide' - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = config @@ -598,6 +616,7 @@ class TestRunnerExceptionFlow: # Create Processor stage from langbot.pkg.pipeline.process import process + processor_stage = process.Processor(pipeline_app) await processor_stage.initialize(query.pipeline_config) @@ -623,7 +642,7 @@ class TestSendResponseBackStage: adapter, platform = fake_platform_adapter # Create query with response message - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() @@ -666,12 +685,12 @@ class TestStageChainIntegration: adapter, platform = fake_platform_adapter # Set fake runner - fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG") + fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG') set_fake_runner(fake_runner) # Create query config = create_minimal_pipeline_config() - query = text_query("ping") + query = text_query('ping') query.adapter = adapter query.pipeline_config = config query.resp_messages = [] @@ -690,7 +709,7 @@ class TestStageChainIntegration: pipeline_app.plugin_connector.emit_event = AsyncMock() pipeline_app.plugin_connector.emit_event.side_effect = [ - mock_event_ctx_preproc, # PreProcessor PromptPreProcessing + mock_event_ctx_preproc, # PreProcessor PromptPreProcessing mock_event_ctx_processor, # Processor NormalMessageReceived ] @@ -711,6 +730,7 @@ class TestStageChainIntegration: # Build resp_message_chain from resp_messages from tests.factories.message import text_chain + for resp_msg in query.resp_messages: if resp_msg.content: query.resp_message_chain.append(text_chain(resp_msg.content)) @@ -737,7 +757,7 @@ class TestStageChainIntegration: adapter, platform = fake_platform_adapter # Create query - query = text_query("hello") + query = text_query('hello') query.adapter = adapter query.pipeline_config = create_minimal_pipeline_config() @@ -754,7 +774,7 @@ class TestStageChainIntegration: pipeline_app.plugin_connector.emit_event = AsyncMock() pipeline_app.plugin_connector.emit_event.side_effect = [ - mock_event_ctx_preproc, # PreProcessor PromptPreProcessing + mock_event_ctx_preproc, # PreProcessor PromptPreProcessing mock_event_ctx_processor, # Processor NormalMessageReceived ] @@ -775,4 +795,4 @@ class TestStageChainIntegration: assert results[0].result_type == entities.ResultType.INTERRUPT # Chain stops here - no resp_messages - assert len(query.resp_messages) == 0 \ No newline at end of file + assert len(query.resp_messages) == 0 diff --git a/tests/smoke/__init__.py b/tests/smoke/__init__.py index 5f7e6721b..a4634f28c 100644 --- a/tests/smoke/__init__.py +++ b/tests/smoke/__init__.py @@ -3,4 +3,4 @@ Smoke tests package. Smoke tests verify basic functionality works without testing edge cases. Run with: uv run pytest tests/smoke/ -q -""" \ No newline at end of file +""" diff --git a/tests/smoke/test_fake_message_flow.py b/tests/smoke/test_fake_message_flow.py index aa1bf827d..5ae195f2f 100644 --- a/tests/smoke/test_fake_message_flow.py +++ b/tests/smoke/test_fake_message_flow.py @@ -39,19 +39,19 @@ class TestFakeMessageFlow: assert app.instance_config is not None # Verify default config - assert app.instance_config.data["command"]["prefix"] == ["/", "!"] - assert app.instance_config.data["command"]["enable"] is True + assert app.instance_config.data['command']['prefix'] == ['/', '!'] + assert app.instance_config.data['command']['enable'] is True @pytest.mark.asyncio async def test_fake_provider_returns_text(self): """Test FakeProvider returns configured response.""" - provider = FakeProvider(default_response="test response") + provider = FakeProvider(default_response='test response') # Create mock model with provider model = fake_model(provider=provider) # Create a simple query - query = text_query("hello") + query = text_query('hello') # Simulate invoke result = await provider.invoke_llm( @@ -63,15 +63,15 @@ class TestFakeMessageFlow: ) assert result is not None - assert result.role == "assistant" - assert result.content == "test response" + assert result.role == 'assistant' + assert result.content == 'test response' @pytest.mark.asyncio async def test_fake_provider_pong(self): """Test FakeProvider returns LANGBOT_FAKE_PONG marker.""" provider = fake_provider_pong() model = fake_model(provider=provider) - query = text_query("ping") + query = text_query('ping') result = await provider.invoke_llm( query=query, @@ -86,9 +86,9 @@ class TestFakeMessageFlow: @pytest.mark.asyncio async def test_fake_provider_streaming(self): """Test FakeProvider streaming response.""" - provider = FakeProvider().returns_streaming(["Hello", " World"]) + provider = FakeProvider().returns_streaming(['Hello', ' World']) model = fake_model(provider=provider) - query = text_query("hello") + query = text_query('hello') chunks = [] # invoke_llm_stream returns an async generator, don't await it @@ -102,8 +102,8 @@ class TestFakeMessageFlow: chunks.append(chunk) assert len(chunks) == 2 - assert chunks[0].content == "Hello" - assert chunks[1].content == " World" + assert chunks[0].content == 'Hello' + assert chunks[1].content == ' World' assert chunks[1].is_final is True @pytest.mark.asyncio @@ -111,9 +111,9 @@ class TestFakeMessageFlow: """Test FakeProvider simulates timeout error.""" provider = FakeProvider().timeout() model = fake_model(provider=provider) - query = text_query("hello") + query = text_query('hello') - with pytest.raises(TimeoutError, match="Provider timeout"): + with pytest.raises(TimeoutError, match='Provider timeout'): await provider.invoke_llm( query=query, model=model, @@ -127,9 +127,9 @@ class TestFakeMessageFlow: """Test FakeProvider simulates rate limit error.""" provider = FakeProvider().rate_limit() model = fake_model(provider=provider) - query = text_query("hello") + query = text_query('hello') - with pytest.raises(Exception, match="Rate limit exceeded"): + with pytest.raises(Exception, match='Rate limit exceeded'): await provider.invoke_llm( query=query, model=model, @@ -142,34 +142,34 @@ class TestFakeMessageFlow: async def test_fake_provider_captures_requests(self): """Test FakeProvider captures request arguments.""" provider = FakeProvider() - model = fake_model(name="gpt-4", provider=provider) - query = text_query("hello") + model = fake_model(name='gpt-4', provider=provider) + query = text_query('hello') await provider.invoke_llm( query=query, model=model, - messages=[{"role": "user", "content": "hello"}], - funcs=[{"name": "test_func"}], - extra_args={"temperature": 0.7}, + messages=[{'role': 'user', 'content': 'hello'}], + funcs=[{'name': 'test_func'}], + extra_args={'temperature': 0.7}, ) captured = provider.get_captured_requests() assert len(captured) == 1 - assert captured[0]["model"] == "gpt-4" - assert captured[0]["messages"] == [{"role": "user", "content": "hello"}] - assert captured[0]["funcs"] == [{"name": "test_func"}] - assert captured[0]["extra_args"] == {"temperature": 0.7} + assert captured[0]['model'] == 'gpt-4' + assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}] + assert captured[0]['funcs'] == [{'name': 'test_func'}] + assert captured[0]['extra_args'] == {'temperature': 0.7} @pytest.mark.asyncio async def test_fake_platform_capture_outbound(self): """Test FakePlatform captures outbound messages.""" - platform = FakePlatform(bot_account_id="test-bot") - query = text_query("hello") + platform = FakePlatform(bot_account_id='test-bot') + query = text_query('hello') # Simulate sending reply from tests.factories.message import text_chain - reply_chain = text_chain("response text") + reply_chain = text_chain('response text') event = query.message_event await platform.reply_message(event, reply_chain, quote_origin=False) @@ -177,38 +177,38 @@ class TestFakeMessageFlow: # Verify captured outbound = platform.get_outbound_messages() assert len(outbound) == 1 - assert outbound[0]["type"] == "reply" - assert outbound[0]["message"] == reply_chain + assert outbound[0]['type'] == 'reply' + assert outbound[0]['message'] == reply_chain @pytest.mark.asyncio async def test_fake_platform_friend_message(self): """Test FakePlatform creates friend message events.""" - platform = FakePlatform(bot_account_id="test-bot") + platform = FakePlatform(bot_account_id='test-bot') event = platform.create_friend_message( - text="hello bot", + text='hello bot', sender_id=12345, - nickname="TestUser", + nickname='TestUser', ) - assert event.type == "FriendMessage" + assert event.type == 'FriendMessage' assert event.sender.id == 12345 - assert event.sender.nickname == "TestUser" - assert str(event.message_chain) == "hello bot" + assert event.sender.nickname == 'TestUser' + assert str(event.message_chain) == 'hello bot' @pytest.mark.asyncio async def test_fake_platform_group_message_with_mention(self): """Test FakePlatform creates group message with @mention.""" - platform = FakePlatform(bot_account_id="test-bot") + platform = FakePlatform(bot_account_id='test-bot') event = platform.create_group_message( - text="hello everyone", + text='hello everyone', sender_id=12345, group_id=99999, mention_bot=True, ) - assert event.type == "GroupMessage" + assert event.type == 'GroupMessage' assert event.sender.id == 12345 assert event.group.id == 99999 @@ -220,54 +220,57 @@ class TestFakeMessageFlow: async def test_query_factories_basic(self): """Test basic query factory functions.""" # Text query - q1 = text_query("hello world") - assert q1.launcher_type.value == "person" - assert str(q1.message_chain) == "hello world" + q1 = text_query('hello world') + assert q1.launcher_type.value == 'person' + assert str(q1.message_chain) == 'hello world' # Group query from tests.factories import group_text_query - q2 = group_text_query("hello group", group_id=88888) - assert q2.launcher_type.value == "group" + + q2 = group_text_query('hello group', group_id=88888) + assert q2.launcher_type.value == 'group' assert q2.launcher_id == 88888 # Command query from tests.factories import command_query - q3 = command_query("help", prefix="/") - assert str(q3.message_chain) == "/help" + + q3 = command_query('help', prefix='/') + assert str(q3.message_chain) == '/help' # Mention query from tests.factories import mention_query - q4 = mention_query("hi", target="test-bot", group_id=77777) - assert q4.launcher_type.value == "group" + + q4 = mention_query('hi', target='test-bot', group_id=77777) + assert q4.launcher_type.value == 'group' @pytest.mark.asyncio async def test_fake_platform_send_failure(self): """Test FakePlatform simulates send failure.""" platform = FakePlatform().send_failure() - query = text_query("hello") + query = text_query('hello') from tests.factories.message import text_chain - with pytest.raises(Exception, match="Platform send failure"): + with pytest.raises(Exception, match='Platform send failure'): await platform.reply_message( query.message_event, - text_chain("response"), + text_chain('response'), ) @pytest.mark.asyncio async def test_mock_platform_adapter(self): """Test mock_platform_adapter helper.""" - platform = FakePlatform(bot_account_id="bot-123") + platform = FakePlatform(bot_account_id='bot-123') adapter = mock_platform_adapter(platform) - assert adapter.bot_account_id == "bot-123" + assert adapter.bot_account_id == 'bot-123' assert adapter._fake_platform is platform # Test reply_message is wired from tests.factories.message import text_chain - query = text_query("test") - await adapter.reply_message(query.message_event, text_chain("response")) + query = text_query('test') + await adapter.reply_message(query.message_event, text_chain('response')) # Verify platform captured it assert len(platform.get_outbound_messages()) == 1 @@ -293,18 +296,18 @@ class TestMessageFlowIntegration: Note: This does NOT run actual LangBot pipeline stages. """ # Setup - platform = FakePlatform(bot_account_id="test-bot") + platform = FakePlatform(bot_account_id='test-bot') provider = fake_provider_pong() model = fake_model(provider=provider) # Create inbound message - query = text_query("ping") + query = text_query('ping') # Simulate provider processing response = await provider.invoke_llm( query=query, model=model, - messages=[{"role": "user", "content": "ping"}], + messages=[{'role': 'user', 'content': 'ping'}], funcs=[], extra_args={}, ) @@ -321,16 +324,16 @@ class TestMessageFlowIntegration: # Verify platform captured outbound outbound = platform.get_outbound_messages() assert len(outbound) == 1 - assert outbound[0]["type"] == "reply" - assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE + assert outbound[0]['type'] == 'reply' + assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE @pytest.mark.asyncio async def test_streaming_message_flow(self): """Smoke test: streaming message flow.""" platform = FakePlatform().supports_streaming() - provider = FakeProvider().returns_streaming(["Hello", " there"]) + provider = FakeProvider().returns_streaming(['Hello', ' there']) model = fake_model(provider=provider) - query = text_query("hi") + query = text_query('hi') chunks = [] async for chunk in provider.invoke_llm_stream( @@ -344,8 +347,8 @@ class TestMessageFlowIntegration: # Verify streaming worked assert len(chunks) == 2 - full_content = "".join(c.content for c in chunks) - assert full_content == "Hello there" + full_content = ''.join(c.content for c in chunks) + assert full_content == 'Hello there' # Verify platform supports streaming - assert await platform.is_stream_output_supported() is True \ No newline at end of file + assert await platform.is_stream_output_supported() is True diff --git a/tests/test_cwe94_debug_exec.py b/tests/test_cwe94_debug_exec.py index 48e08d1a2..b31085981 100644 --- a/tests/test_cwe94_debug_exec.py +++ b/tests/test_cwe94_debug_exec.py @@ -15,22 +15,12 @@ import pathlib # Resolve project root (one level up from tests/) _PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent -VULN_FILE = ( - _PROJECT_ROOT - / "src" - / "langbot" - / "pkg" - / "api" - / "http" - / "controller" - / "groups" - / "system.py" -) +VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py' def test_no_exec_call_in_system_controller(): """Verify there is no exec() call in system.py that takes user input.""" - with open(VULN_FILE, "r") as f: + with open(VULN_FILE, 'r') as f: source = f.read() tree = ast.parse(source) @@ -40,27 +30,26 @@ def test_no_exec_call_in_system_controller(): if isinstance(node, ast.Call): func = node.func # Match bare exec() call - if isinstance(func, ast.Name) and func.id == "exec": + if isinstance(func, ast.Name) and func.id == 'exec': exec_calls.append(node.lineno) assert len(exec_calls) == 0, ( - f"Found exec() call(s) at line(s) {exec_calls} in system.py. " - "User-supplied code must never be passed to exec()." + f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().' ) def test_no_debug_exec_route(): """Verify the /debug/exec route is not registered.""" - with open(VULN_FILE, "r") as f: + with open(VULN_FILE, 'r') as f: source = f.read() - assert "debug/exec" not in source, ( - "The /debug/exec route still exists in system.py. " - "This endpoint allows arbitrary code execution and must be removed." + assert 'debug/exec' not in source, ( + 'The /debug/exec route still exists in system.py. ' + 'This endpoint allows arbitrary code execution and must be removed.' ) -if __name__ == "__main__": +if __name__ == '__main__': test_no_exec_call_in_system_controller() test_no_debug_exec_route() - print("All tests passed!") + print('All tests passed!') diff --git a/tests/unit_tests/api/__init__.py b/tests/unit_tests/api/__init__.py index d8628d82d..42c4689ce 100644 --- a/tests/unit_tests/api/__init__.py +++ b/tests/unit_tests/api/__init__.py @@ -1 +1 @@ -"""Unit tests for LangBot API HTTP service layer.""" \ No newline at end of file +"""Unit tests for LangBot API HTTP service layer.""" diff --git a/tests/unit_tests/api/service/__init__.py b/tests/unit_tests/api/service/__init__.py index 67828f4d8..7d53c4c5f 100644 --- a/tests/unit_tests/api/service/__init__.py +++ b/tests/unit_tests/api/service/__init__.py @@ -13,4 +13,4 @@ Does NOT: - Call real provider/platform/network Uses tests.factories.FakeApp as base mock application. -""" \ No newline at end of file +""" diff --git a/tests/unit_tests/api/service/test_apikey_service.py b/tests/unit_tests/api/service/test_apikey_service.py index e71879874..287a21bad 100644 --- a/tests/unit_tests/api/service/test_apikey_service.py +++ b/tests/unit_tests/api/service/test_apikey_service.py @@ -132,9 +132,7 @@ class TestApiKeyServiceCreateApiKey: with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'): result = await service.create_api_key('New Key', 'Test description') - assert insert_params == [ - {'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'} - ] + assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}] assert result['key'].startswith('lbk_') assert result['key'] == 'lbk_fixed-token' assert result['name'] == 'New Key' diff --git a/tests/unit_tests/api/service/test_bot_service.py b/tests/unit_tests/api/service/test_bot_service.py index c1e5abfe6..8a6d0ad2a 100644 --- a/tests/unit_tests/api/service/test_bot_service.py +++ b/tests/unit_tests/api/service/test_bot_service.py @@ -303,13 +303,7 @@ class TestBotServiceCreateBot: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() ap.instance_config = SimpleNamespace() - ap.instance_config.data = { - 'system': { - 'limitation': { - 'max_bots': 2 - } - } - } + ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}} ap.platform_mgr = SimpleNamespace() ap.platform_mgr.load_bot = AsyncMock() @@ -318,9 +312,7 @@ class TestBotServiceCreateBot: bot2 = _create_mock_bot(bot_uuid='uuid-2') mock_result = _create_mock_result([bot1, bot2]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'uuid-1', 'name': 'Bot 1'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}) service = BotService(ap) @@ -352,6 +344,7 @@ class TestBotServiceCreateBot: bot_result.first = Mock(return_value=_create_mock_bot()) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -362,9 +355,7 @@ class TestBotServiceCreateBot: return bot_result # Get bot ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'new-uuid', 'name': 'New Bot'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'}) service = BotService(ap) @@ -397,6 +388,7 @@ class TestBotServiceCreateBot: bot_result.first = Mock(return_value=_create_mock_bot()) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -492,6 +484,7 @@ class TestBotServiceUpdateBot: pipeline_result.first = Mock(return_value=mock_pipeline) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -582,10 +575,9 @@ class TestBotServiceListEventLogs: # Mock runtime bot with logger runtime_bot = SimpleNamespace() runtime_bot.logger = SimpleNamespace() - runtime_bot.logger.get_logs = AsyncMock(return_value=( - [SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], - 5 - )) + runtime_bot.logger.get_logs = AsyncMock( + return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5) + ) ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot) service = BotService(ap) @@ -646,11 +638,7 @@ class TestBotServiceSendMessage: service = BotService(ap) # Execute with valid message chain format - message_chain_data = { - 'messages': [ - {'type': 'text', 'data': {'text': 'Hello'}} - ] - } + message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]} # Patch the import location - the module imports inside the function with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain: diff --git a/tests/unit_tests/api/service/test_knowledge_service.py b/tests/unit_tests/api/service/test_knowledge_service.py index 87aeddcff..1e0592b01 100644 --- a/tests/unit_tests/api/service/test_knowledge_service.py +++ b/tests/unit_tests/api/service/test_knowledge_service.py @@ -6,6 +6,7 @@ Tests cover: - Knowledge engine discovery - File operations """ + from __future__ import annotations import pytest @@ -52,9 +53,7 @@ class TestGetKnowledgeBases: """Test that it returns all knowledge base details.""" knowledge_module = get_knowledge_service_module() mock_app = create_mock_app() - mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock( - return_value=[{'uuid': 'kb1', 'name': 'KB1'}] - ) + mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}]) service = knowledge_module.KnowledgeService(mock_app) result = await service.get_knowledge_bases() @@ -83,9 +82,7 @@ class TestGetKnowledgeBase: """Test that it returns specific KB details.""" knowledge_module = get_knowledge_service_module() mock_app = create_mock_app() - mock_app.rag_mgr.get_knowledge_base_details = AsyncMock( - return_value={'uuid': 'kb1', 'name': 'KB1'} - ) + mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'}) service = knowledge_module.KnowledgeService(mock_app) result = await service.get_knowledge_base('kb1') @@ -153,9 +150,7 @@ class TestCreateKnowledgeBase: service = knowledge_module.KnowledgeService(mock_app) - await service.create_knowledge_base({ - 'knowledge_engine_plugin_id': 'author/engine' - }) + await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'}) # Check that default name 'Untitled' was used call_args = mock_app.rag_mgr.create_knowledge_base.call_args @@ -170,20 +165,21 @@ class TestUpdateKnowledgeBase: """Test that only mutable fields are updated.""" knowledge_module = get_knowledge_service_module() mock_app = create_mock_app() - mock_app.rag_mgr.get_knowledge_base_details = AsyncMock( - return_value={'uuid': 'kb1', 'name': 'Updated'} - ) + mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'}) mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock() mock_app.rag_mgr.load_knowledge_base = AsyncMock() service = knowledge_module.KnowledgeService(mock_app) # Pass both mutable and immutable fields - await service.update_knowledge_base('kb1', { - 'name': 'New Name', - 'description': 'New desc', - 'uuid': 'should_be_filtered', # immutable - }) + await service.update_knowledge_base( + 'kb1', + { + 'name': 'New Name', + 'description': 'New desc', + 'uuid': 'should_be_filtered', # immutable + }, + ) # Check that only mutable fields were passed to update call_args = mock_app.persistence_mgr.execute_async.call_args @@ -288,9 +284,7 @@ class TestListKnowledgeEngines: """Test that it returns empty list and logs warning on exception.""" knowledge_module = get_knowledge_service_module() mock_app = create_mock_app() - mock_app.plugin_connector.list_knowledge_engines = AsyncMock( - side_effect=Exception('Connection error') - ) + mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error')) service = knowledge_module.KnowledgeService(mock_app) result = await service.list_knowledge_engines() @@ -386,12 +380,10 @@ class TestGetEngineSchemas: """Test that it returns empty dict and logs warning on exception.""" knowledge_module = get_knowledge_service_module() mock_app = create_mock_app() - mock_app.plugin_connector.get_rag_creation_schema = AsyncMock( - side_effect=Exception('Plugin error') - ) + mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error')) service = knowledge_module.KnowledgeService(mock_app) result = await service.get_engine_creation_schema('author/engine') assert result == {} - mock_app.logger.warning.assert_called_once() \ No newline at end of file + mock_app.logger.warning.assert_called_once() diff --git a/tests/unit_tests/api/service/test_maintenance_service.py b/tests/unit_tests/api/service/test_maintenance_service.py index fcedf8b4e..8d5b5b0df 100644 --- a/tests/unit_tests/api/service/test_maintenance_service.py +++ b/tests/unit_tests/api/service/test_maintenance_service.py @@ -174,9 +174,7 @@ class TestMaintenanceServiceGetStorageAnalysis: # Setup ap = SimpleNamespace() ap.instance_config = SimpleNamespace() - ap.instance_config.data = { - 'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}} - } + ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}} ap.persistence_mgr = SimpleNamespace() ap.logger = SimpleNamespace() ap.logger.warning = Mock() @@ -292,12 +290,8 @@ class TestMaintenanceServiceGetStorageAnalysis: service._file_count = Mock(return_value=0) service._monitoring_counts = AsyncMock(return_value={}) service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0}) - service._expired_uploaded_candidates = AsyncMock(return_value=[ - {'key': 'old_file', 'size_bytes': 100} - ]) - service._expired_log_candidates = Mock(return_value=[ - {'name': 'old_log', 'size_bytes': 50} - ]) + service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}]) + service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}]) # Execute result = await service.get_storage_analysis() @@ -367,6 +361,7 @@ class TestMaintenanceServiceBinaryStorageStats: size_result = _create_mock_result(scalar_value=5000) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -396,6 +391,7 @@ class TestMaintenanceServiceBinaryStorageStats: count_result = _create_mock_result(scalar_value=5) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -821,4 +817,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates: result = service._expired_local_upload_candidates(7, include_paths=True) # Verify - path included - assert 'path' in result[0] \ No newline at end of file + assert 'path' in result[0] diff --git a/tests/unit_tests/api/service/test_mcp_service.py b/tests/unit_tests/api/service/test_mcp_service.py index 7f6ae83c6..17c746e73 100644 --- a/tests/unit_tests/api/service/test_mcp_service.py +++ b/tests/unit_tests/api/service/test_mcp_service.py @@ -186,13 +186,7 @@ class TestMCPServiceCreateMCPServer: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() ap.instance_config = SimpleNamespace() - ap.instance_config.data = { - 'system': { - 'limitation': { - 'max_extensions': 2 - } - } - } + ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}} ap.plugin_connector = SimpleNamespace() ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins @@ -252,6 +246,7 @@ class TestMCPServiceCreateMCPServer: server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -361,6 +356,7 @@ class TestMCPServiceUpdateMCPServer: old_server = _create_mock_mcp_server(name='Old Server', enable=True) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -394,6 +390,7 @@ class TestMCPServiceUpdateMCPServer: updated_server = _create_mock_mcp_server(name='Old Server', enable=True) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -432,6 +429,7 @@ class TestMCPServiceUpdateMCPServer: # Mock for: first select -> update -> second select (for updated server) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -465,6 +463,7 @@ class TestMCPServiceUpdateMCPServer: # Mock execute for select and update call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -499,6 +498,7 @@ class TestMCPServiceDeleteMCPServer: server = _create_mock_mcp_server(name='Server to Delete') call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -530,6 +530,7 @@ class TestMCPServiceDeleteMCPServer: server = _create_mock_mcp_server(name='Not in Sessions') call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -559,6 +560,7 @@ class TestMCPServiceDeleteMCPServer: # No server found call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -596,9 +598,7 @@ class TestMCPServiceTestMCPServer: ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session) ap.task_mgr = SimpleNamespace() - ap.task_mgr.create_user_task = Mock( - return_value=SimpleNamespace(id=123) - ) + ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123)) service = MCPService(ap) @@ -634,9 +634,7 @@ class TestMCPServiceTestMCPServer: ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session) ap.task_mgr = SimpleNamespace() - ap.task_mgr.create_user_task = Mock( - return_value=SimpleNamespace(id=456) - ) + ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456)) service = MCPService(ap) @@ -645,4 +643,4 @@ class TestMCPServiceTestMCPServer: # Verify - load_mcp_server called ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once() - assert task_id == 456 \ No newline at end of file + assert task_id == 456 diff --git a/tests/unit_tests/api/service/test_model_service.py b/tests/unit_tests/api/service/test_model_service.py index a0ffc92dd..42129ed3b 100644 --- a/tests/unit_tests/api/service/test_model_service.py +++ b/tests/unit_tests/api/service/test_model_service.py @@ -167,6 +167,7 @@ class TestLLMModelsServiceGetLLMModels: mock_provider_result = _create_mock_result([]) call_count = 0 + async def mock_execute(query): return mock_result if call_count == 0 else mock_provider_result @@ -200,6 +201,7 @@ class TestLLMModelsServiceGetLLMModels: mock_provider_result = _create_mock_result([provider]) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -239,6 +241,7 @@ class TestLLMModelsServiceGetLLMModels: mock_provider_result = _create_mock_result([provider]) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -279,6 +282,7 @@ class TestLLMModelsServiceGetLLMModel: mock_provider_result = _create_mock_result([], first_item=provider) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -337,9 +341,7 @@ class TestLLMModelsServiceGetLLMModelsByProvider: mock_result = _create_mock_result([model1, model2]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'model-1', 'name': 'Model 1'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'}) service = LLMModelsService(ap) @@ -371,12 +373,14 @@ class TestLLMModelsServiceCreateLLMModel: service = LLMModelsService(ap) # Execute - model_uuid = await service.create_llm_model({ - 'name': 'New LLM', - 'provider_uuid': 'provider-uuid', - 'abilities': [], - 'extra_args': {}, - }) + model_uuid = await service.create_llm_model( + { + 'name': 'New LLM', + 'provider_uuid': 'provider-uuid', + 'abilities': [], + 'extra_args': {}, + } + ) # Verify assert model_uuid is not None @@ -400,13 +404,16 @@ class TestLLMModelsServiceCreateLLMModel: service = LLMModelsService(ap) # Execute - model_uuid = await service.create_llm_model({ - 'uuid': 'preserved-uuid', - 'name': 'Preserved UUID Model', - 'provider_uuid': 'provider-uuid', - 'abilities': [], - 'extra_args': {}, - }, preserve_uuid=True) + model_uuid = await service.create_llm_model( + { + 'uuid': 'preserved-uuid', + 'name': 'Preserved UUID Model', + 'provider_uuid': 'provider-uuid', + 'abilities': [], + 'extra_args': {}, + }, + preserve_uuid=True, + ) # Verify assert model_uuid == 'preserved-uuid' @@ -459,12 +466,14 @@ class TestLLMModelsServiceCreateLLMModel: # Execute & Verify with pytest.raises(Exception, match='provider not found'): - await service.create_llm_model({ - 'name': 'No Provider Model', - 'provider_uuid': 'nonexistent-provider', - 'abilities': [], - 'extra_args': {}, - }) + await service.create_llm_model( + { + 'name': 'No Provider Model', + 'provider_uuid': 'nonexistent-provider', + 'abilities': [], + 'extra_args': {}, + } + ) async def test_create_llm_model_with_provider_data(self): """Creates provider when provider data provided.""" @@ -490,16 +499,18 @@ class TestLLMModelsServiceCreateLLMModel: service = LLMModelsService(ap) # Execute - with provider data (no UUID) - result_uuid = await service.create_llm_model({ - 'name': 'Model with New Provider', - 'provider': { - 'requester': 'openai', - 'base_url': 'https://api.openai.com', - 'api_keys': ['key'], - }, - 'abilities': [], - 'extra_args': {}, - }) + result_uuid = await service.create_llm_model( + { + 'name': 'Model with New Provider', + 'provider': { + 'requester': 'openai', + 'base_url': 'https://api.openai.com', + 'api_keys': ['key'], + }, + 'abilities': [], + 'extra_args': {}, + } + ) # Verify - provider_service was called and UUID generated ap.provider_service.find_or_create_provider.assert_called_once() @@ -525,11 +536,14 @@ class TestLLMModelsServiceUpdateLLMModel: service = LLMModelsService(ap) # Execute - await service.update_llm_model('existing-uuid', { - 'uuid': 'should-be-removed', - 'name': 'Updated Name', - 'provider_uuid': 'provider-uuid', - }) + await service.update_llm_model( + 'existing-uuid', + { + 'uuid': 'should-be-removed', + 'name': 'Updated Name', + 'provider_uuid': 'provider-uuid', + }, + ) # Verify - remove and load called ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid') @@ -549,10 +563,13 @@ class TestLLMModelsServiceUpdateLLMModel: # Execute & Verify with pytest.raises(Exception, match='provider not found'): - await service.update_llm_model('model-uuid', { - 'name': 'Update', - 'provider_uuid': 'nonexistent-provider', - }) + await service.update_llm_model( + 'model-uuid', + { + 'name': 'Update', + 'provider_uuid': 'nonexistent-provider', + }, + ) async def test_update_llm_model_reloads_context_length_as_column(self): """Updates runtime model with context_length outside extra_args.""" @@ -618,9 +635,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels: mock_result = _create_mock_result([]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'embedding-uuid', 'name': 'Test'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'}) service = EmbeddingModelsService(ap) @@ -643,6 +658,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels: mock_provider_result = _create_mock_result([provider]) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -683,6 +699,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModel: mock_provider_result = _create_mock_result([], first_item=provider) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -742,11 +759,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel: service = EmbeddingModelsService(ap) # Execute - model_uuid = await service.create_embedding_model({ - 'name': 'New Embedding', - 'provider_uuid': 'provider-uuid', - 'extra_args': {}, - }) + model_uuid = await service.create_embedding_model( + { + 'name': 'New Embedding', + 'provider_uuid': 'provider-uuid', + 'extra_args': {}, + } + ) # Verify assert model_uuid is not None @@ -767,11 +786,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel: # Execute & Verify with pytest.raises(Exception, match='provider not found'): - await service.create_embedding_model({ - 'name': 'No Provider Embedding', - 'provider_uuid': 'nonexistent', - 'extra_args': {}, - }) + await service.create_embedding_model( + { + 'name': 'No Provider Embedding', + 'provider_uuid': 'nonexistent', + 'extra_args': {}, + } + ) class TestEmbeddingModelsServiceDeleteEmbeddingModel: @@ -829,6 +850,7 @@ class TestRerankModelsServiceGetRerankModels: mock_provider_result = _create_mock_result([provider]) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -869,6 +891,7 @@ class TestRerankModelsServiceGetRerankModel: mock_provider_result = _create_mock_result([], first_item=provider) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -928,11 +951,13 @@ class TestRerankModelsServiceCreateRerankModel: service = RerankModelsService(ap) # Execute - model_uuid = await service.create_rerank_model({ - 'name': 'New Rerank', - 'provider_uuid': 'provider-uuid', - 'extra_args': {}, - }) + model_uuid = await service.create_rerank_model( + { + 'name': 'New Rerank', + 'provider_uuid': 'provider-uuid', + 'extra_args': {}, + } + ) # Verify assert model_uuid is not None @@ -952,11 +977,13 @@ class TestRerankModelsServiceCreateRerankModel: # Execute & Verify with pytest.raises(Exception, match='provider not found'): - await service.create_rerank_model({ - 'name': 'No Provider Rerank', - 'provider_uuid': 'nonexistent', - 'extra_args': {}, - }) + await service.create_rerank_model( + { + 'name': 'No Provider Rerank', + 'provider_uuid': 'nonexistent', + 'extra_args': {}, + } + ) class TestRerankModelsServiceDeleteRerankModel: @@ -995,9 +1022,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider: mock_result = _create_mock_result([model1, model2]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'emb-1', 'name': 'Embedding 1'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}) service = EmbeddingModelsService(ap) @@ -1022,9 +1047,7 @@ class TestRerankModelsServiceGetRerankModelsByProvider: mock_result = _create_mock_result([model1, model2]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}) service = RerankModelsService(ap) @@ -1042,14 +1065,10 @@ class TestValidateProviderSupports: def _make_ap(requester_name: str, support_type): """Build a fake ap whose model_mgr resolves a manifest with support_type.""" manifest = SimpleNamespace(spec={'support_type': support_type}) - runtime_provider = SimpleNamespace( - provider_entity=SimpleNamespace(requester=requester_name) - ) + runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name)) model_mgr = SimpleNamespace( provider_dict={'p1': runtime_provider}, - get_available_requester_manifest_by_name=lambda name: manifest - if name == requester_name - else None, + get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None, ) return SimpleNamespace(model_mgr=model_mgr) @@ -1066,9 +1085,7 @@ class TestValidateProviderSupports: async def test_allows_when_support_type_missing(self): # Manifest without support_type must not block (backward compatible) manifest = SimpleNamespace(spec={}) - runtime_provider = SimpleNamespace( - provider_entity=SimpleNamespace(requester='legacy') - ) + runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy')) model_mgr = SimpleNamespace( provider_dict={'p1': runtime_provider}, get_available_requester_manifest_by_name=lambda name: manifest, diff --git a/tests/unit_tests/api/service/test_pipeline_service.py b/tests/unit_tests/api/service/test_pipeline_service.py index a84adab8f..28d2fc117 100644 --- a/tests/unit_tests/api/service/test_pipeline_service.py +++ b/tests/unit_tests/api/service/test_pipeline_service.py @@ -215,13 +215,7 @@ class TestPipelineServiceCreatePipeline: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() ap.instance_config = SimpleNamespace() - ap.instance_config.data = { - 'system': { - 'limitation': { - 'max_pipelines': 2 - } - } - } + ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}} ap.pipeline_mgr = SimpleNamespace() ap.pipeline_mgr.load_pipeline = AsyncMock() ap.ver_mgr = SimpleNamespace() @@ -229,9 +223,7 @@ class TestPipelineServiceCreatePipeline: mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()]) ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}) service = PipelineService(ap) @@ -258,14 +250,14 @@ class TestPipelineServiceCreatePipeline: # Mock persistence for insert ap.persistence_mgr.execute_async = AsyncMock() - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}) # Mock the file read for default config - patch at the utils module level default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}} with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): - with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'): + with patch( + 'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json' + ): bot_uuid = await service.create_pipeline({'name': 'New Pipeline'}) # Verify @@ -286,7 +278,9 @@ class TestPipelineServiceCreatePipeline: service = PipelineService(ap) service.get_pipelines = AsyncMock(return_value=[]) - service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}) + service.get_pipeline = AsyncMock( + return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True} + ) ap.persistence_mgr.execute_async = AsyncMock() ap.persistence_mgr.serialize_model = Mock( @@ -296,7 +290,9 @@ class TestPipelineServiceCreatePipeline: # Mock the file read default_config = {} with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): - with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'): + with patch( + 'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json' + ): await service.create_pipeline({'name': 'Default Pipeline'}, default=True) # Verify - execute was called @@ -316,10 +312,12 @@ class TestPipelineServiceCreatePipeline: service = PipelineService(ap) service.get_pipelines = AsyncMock(return_value=[]) - service.get_pipeline = AsyncMock(return_value={ - 'uuid': 'new-uuid', - 'extensions_preferences': {}, - }) + service.get_pipeline = AsyncMock( + return_value={ + 'uuid': 'new-uuid', + 'extensions_preferences': {}, + } + ) insert_params = [] @@ -339,7 +337,9 @@ class TestPipelineServiceCreatePipeline: default_config = {} with patch('builtins.open', mock_open(read_data=json.dumps(default_config))): - with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'): + with patch( + 'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json' + ): await service.create_pipeline({'name': 'New Pipeline'}) assert len(insert_params) == 1 @@ -353,6 +353,7 @@ class TestPipelineServiceCreatePipeline: class _MockResultWithBots: """Helper class to mock SQLAlchemy result with iterable .all() method.""" + def __init__(self, bots_list): self._bots_list = bots_list @@ -428,6 +429,7 @@ class TestPipelineServiceUpdatePipeline: # 1. UPDATE (line 125) - returns Mock (no result needed) # 2. SELECT bots (line 136) - returns bot_result with .all() call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -528,13 +530,7 @@ class TestPipelineServiceCopyPipeline: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() ap.instance_config = SimpleNamespace() - ap.instance_config.data = { - 'system': { - 'limitation': { - 'max_pipelines': 2 - } - } - } + ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}} ap.pipeline_mgr = SimpleNamespace() ap.pipeline_mgr.load_pipeline = AsyncMock() ap.ver_mgr = SimpleNamespace() @@ -542,10 +538,12 @@ class TestPipelineServiceCopyPipeline: service = PipelineService(ap) # Mock get_pipelines to return 2 pipelines - service.get_pipelines = AsyncMock(return_value=[ - {'uuid': 'uuid-1', 'name': 'Pipeline 1'}, - {'uuid': 'uuid-2', 'name': 'Pipeline 2'}, - ]) + service.get_pipelines = AsyncMock( + return_value=[ + {'uuid': 'uuid-1', 'name': 'Pipeline 1'}, + {'uuid': 'uuid-2', 'name': 'Pipeline 2'}, + ] + ) # Execute & Verify with pytest.raises(ValueError, match='Maximum number of pipelines'): @@ -642,9 +640,7 @@ class TestPipelineServiceCopyPipeline: service = PipelineService(ap) service.get_pipelines = AsyncMock(return_value=[]) ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original)) - ap.persistence_mgr.serialize_model = Mock( - return_value={'uuid': 'copy-uuid', 'is_default': False} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False}) service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False}) @@ -681,11 +677,10 @@ class TestPipelineServiceUpdatePipelineExtensions: ap.pipeline_mgr.remove_pipeline = AsyncMock() ap.pipeline_mgr.load_pipeline = AsyncMock() - original_pipeline = _create_mock_pipeline( - extensions_preferences={'enable_all_plugins': True, 'plugins': []} - ) + original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []}) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -700,7 +695,7 @@ class TestPipelineServiceUpdatePipelineExtensions: 'extensions_preferences': { 'enable_all_plugins': False, 'plugins': [{'plugin_uuid': 'plugin-1'}], - } + }, } ) @@ -711,7 +706,7 @@ class TestPipelineServiceUpdatePipelineExtensions: 'extensions_preferences': { 'enable_all_plugins': False, 'plugins': [{'plugin_uuid': 'plugin-1'}], - } + }, } ) @@ -738,6 +733,7 @@ class TestPipelineServiceUpdatePipelineExtensions: original_pipeline = _create_mock_pipeline() call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -752,7 +748,7 @@ class TestPipelineServiceUpdatePipelineExtensions: 'extensions_preferences': { 'enable_all_mcp_servers': False, 'mcp_servers': ['mcp-server-1'], - } + }, } ) @@ -794,6 +790,7 @@ class TestPipelineServiceUpdatePipelineExtensions: ) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 diff --git a/tests/unit_tests/api/service/test_provider_service.py b/tests/unit_tests/api/service/test_provider_service.py index 4c3f818d4..8b308af8d 100644 --- a/tests/unit_tests/api/service/test_provider_service.py +++ b/tests/unit_tests/api/service/test_provider_service.py @@ -245,12 +245,14 @@ class TestModelProviderServiceCreateProvider: service = ModelProviderService(ap) # Execute - provider_uuid = await service.create_provider({ - 'name': 'New Provider', - 'requester': 'openai', - 'base_url': 'https://api.openai.com', - 'api_keys': ['key'], - }) + provider_uuid = await service.create_provider( + { + 'name': 'New Provider', + 'requester': 'openai', + 'base_url': 'https://api.openai.com', + 'api_keys': ['key'], + } + ) # Verify - UUID is generated assert provider_uuid is not None @@ -274,12 +276,14 @@ class TestModelProviderServiceCreateProvider: service = ModelProviderService(ap) # Execute - result_uuid = await service.create_provider({ - 'name': 'Runtime Provider', - 'requester': 'openai', - 'base_url': 'https://api.openai.com', - 'api_keys': ['key'], - }) + result_uuid = await service.create_provider( + { + 'name': 'Runtime Provider', + 'requester': 'openai', + 'base_url': 'https://api.openai.com', + 'api_keys': ['key'], + } + ) # Verify - provider added to runtime dict and UUID generated ap.model_mgr.load_provider.assert_called_once() @@ -302,10 +306,13 @@ class TestModelProviderServiceUpdateProvider: service = ModelProviderService(ap) # Execute - await service.update_provider('existing-uuid', { - 'uuid': 'should-be-removed', # Will be removed - 'name': 'Updated Name', - }) + await service.update_provider( + 'existing-uuid', + { + 'uuid': 'should-be-removed', # Will be removed + 'name': 'Updated Name', + }, + ) # Verify - reload called ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid') @@ -364,6 +371,7 @@ class TestModelProviderServiceDeleteProvider: rerank_result.first = Mock(return_value=None) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -396,6 +404,7 @@ class TestModelProviderServiceDeleteProvider: rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -454,6 +463,7 @@ class TestModelProviderServiceGetProviderModelCounts: rerank_result.scalar = Mock(return_value=1) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -637,9 +647,7 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys: await service.update_space_model_provider_api_keys('space-api-key') # Verify - update and reload called for Space provider UUID - ap.model_mgr.reload_provider.assert_called_once_with( - '00000000-0000-0000-0000-000000000000' - ) + ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000') class TestModelProviderServiceScanProviderModels: @@ -795,9 +803,7 @@ class TestModelProviderServiceScanProviderModels: runtime_provider.token_mgr = Mock() runtime_provider.token_mgr.get_token = Mock(return_value='token') runtime_provider.token_mgr.tokens = ['token'] - runtime_provider.requester.scan_models = AsyncMock( - side_effect=NotImplementedError('scan not supported') - ) + runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported')) ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider) service = ModelProviderService(ap) @@ -848,9 +854,7 @@ class TestModelProviderServiceScanProviderModels: ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider) # Mock existing LLM model - ap.llm_model_service.get_llm_models_by_provider = AsyncMock( - return_value=[{'name': 'Existing Model'}] - ) + ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}]) ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[]) service = ModelProviderService(ap) @@ -863,4 +867,4 @@ class TestModelProviderServiceScanProviderModels: assert existing_model['already_added'] is True new_model = next(m for m in result['models'] if m['name'] == 'New Model') - assert new_model['already_added'] is False \ No newline at end of file + assert new_model['already_added'] is False diff --git a/tests/unit_tests/api/service/test_space_service.py b/tests/unit_tests/api/service/test_space_service.py index 968753133..f48b18937 100644 --- a/tests/unit_tests/api/service/test_space_service.py +++ b/tests/unit_tests/api/service/test_space_service.py @@ -393,14 +393,16 @@ class TestSpaceServiceRefreshToken: # Mock HTTP response mock_response = MagicMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - 'code': 0, - 'data': { - 'access_token': 'new_access_token', - 'refresh_token': 'new_refresh_token', - 'expires_in': 3600, + mock_response.json = AsyncMock( + return_value={ + 'code': 0, + 'data': { + 'access_token': 'new_access_token', + 'refresh_token': 'new_refresh_token', + 'expires_in': 3600, + }, } - }) + ) with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: mock_session_obj = MagicMock() @@ -429,10 +431,12 @@ class TestSpaceServiceRefreshToken: # Mock HTTP response with error mock_response = MagicMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - 'code': 1, - 'msg': 'Invalid refresh token', - }) + mock_response.json = AsyncMock( + return_value={ + 'code': 1, + 'msg': 'Invalid refresh token', + } + ) mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}') with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: @@ -489,14 +493,16 @@ class TestSpaceServiceExchangeOAuthCode: # Mock HTTP response mock_response = MagicMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - 'code': 0, - 'data': { - 'access_token': 'new_access_token', - 'refresh_token': 'new_refresh_token', - 'expires_in': 3600, + mock_response.json = AsyncMock( + return_value={ + 'code': 0, + 'data': { + 'access_token': 'new_access_token', + 'refresh_token': 'new_refresh_token', + 'expires_in': 3600, + }, } - }) + ) with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: mock_session_obj = MagicMock() @@ -555,13 +561,15 @@ class TestSpaceServiceGetUserInfoRaw: # Mock HTTP response mock_response = MagicMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - 'code': 0, - 'data': { - 'email': 'test@example.com', - 'credits': 100, + mock_response.json = AsyncMock( + return_value={ + 'code': 0, + 'data': { + 'email': 'test@example.com', + 'credits': 100, + }, } - }) + ) with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: mock_session_obj = MagicMock() @@ -669,27 +677,29 @@ class TestSpaceServiceGetModels: # Mock HTTP response with proper model data matching SpaceModel schema mock_response = MagicMock() mock_response.status = 200 - mock_response.json = AsyncMock(return_value={ - 'code': 0, - 'data': { - 'models': [ - { - 'uuid': 'uuid-1', - 'model_id': 'model-1', - 'provider': 'provider-1', - 'category': 'chat', - 'status': 'active', - }, - { - 'uuid': 'uuid-2', - 'model_id': 'model-2', - 'provider': 'provider-2', - 'category': 'chat', - 'status': 'active', - }, - ] + mock_response.json = AsyncMock( + return_value={ + 'code': 0, + 'data': { + 'models': [ + { + 'uuid': 'uuid-1', + 'model_id': 'model-1', + 'provider': 'provider-1', + 'category': 'chat', + 'status': 'active', + }, + { + 'uuid': 'uuid-2', + 'model_id': 'model-2', + 'provider': 'provider-2', + 'category': 'chat', + 'status': 'active', + }, + ] + }, } - }) + ) with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session: mock_session_obj = MagicMock() @@ -775,4 +785,4 @@ class TestSpaceServiceCreditsCache: # Verify - cache updated assert result == 500 assert 'test@example.com' in service._credits_cache - assert service._credits_cache['test@example.com'][0] == 500 \ No newline at end of file + assert service._credits_cache['test@example.com'][0] == 500 diff --git a/tests/unit_tests/api/service/test_user_service.py b/tests/unit_tests/api/service/test_user_service.py index 54d0674e0..c5d37f167 100644 --- a/tests/unit_tests/api/service/test_user_service.py +++ b/tests/unit_tests/api/service/test_user_service.py @@ -495,6 +495,7 @@ class TestUserServiceCreateOrUpdateSpaceUser: # First call (line 138) returns None, second call (line 194) returns new_user call_count = 0 + async def mock_get_by_space_uuid(uuid): nonlocal call_count call_count += 1 @@ -565,6 +566,7 @@ class TestUserServiceCreateOrUpdateSpaceUser: # First call (line 138) returns None, second call (line 194) returns new_user call_count = 0 + async def mock_get_by_space_uuid(uuid): nonlocal call_count call_count += 1 @@ -605,4 +607,4 @@ class TestUserServiceCreateUserLock: # Verify lock exists assert hasattr(service, '_create_user_lock') - assert service._create_user_lock is not None \ No newline at end of file + assert service._create_user_lock is not None diff --git a/tests/unit_tests/api/service/test_webhook_service.py b/tests/unit_tests/api/service/test_webhook_service.py index ef2469c1e..7a5a075ef 100644 --- a/tests/unit_tests/api/service/test_webhook_service.py +++ b/tests/unit_tests/api/service/test_webhook_service.py @@ -132,6 +132,7 @@ class TestWebhookServiceCreateWebhook: # execute_async returns different results call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -181,6 +182,7 @@ class TestWebhookServiceCreateWebhook: ) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -217,6 +219,7 @@ class TestWebhookServiceCreateWebhook: created_webhook = _create_mock_webhook(webhook_id=1, enabled=False) call_count = 0 + async def mock_execute(query): nonlocal call_count call_count += 1 @@ -225,9 +228,7 @@ class TestWebhookServiceCreateWebhook: return _create_mock_result(first_item=created_webhook) ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) - ap.persistence_mgr.serialize_model = Mock( - return_value={'id': 1, 'enabled': False} - ) + ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False}) service = WebhookService(ap) @@ -503,4 +504,4 @@ class TestWebhookServiceGetEnabledWebhooks: result = await service.get_enabled_webhooks() # Verify - should be empty (SQL would filter disabled) - assert result == [] \ No newline at end of file + assert result == [] diff --git a/tests/unit_tests/box/test_box_service.py b/tests/unit_tests/box/test_box_service.py index 4e947653d..c59a1c5e9 100644 --- a/tests/unit_tests/box/test_box_service.py +++ b/tests/unit_tests/box/test_box_service.py @@ -407,7 +407,9 @@ def test_box_service_forced_template_ignores_pipeline_config(): launcher_type='person', launcher_id='test_user', sender_id='test_user', - pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}}, + pipeline_config={ + 'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}} + }, ) assert service.resolve_box_session_id(query) == 'global' @@ -1527,9 +1529,7 @@ class TestBuildSkillExtraMounts: {'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'}, ] # No skill is dropped, so no "missing" warning should be logged. - assert not any( - 'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list - ) + assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list) def test_skips_skill_with_empty_package_root(self): logger = Mock() diff --git a/tests/unit_tests/command/__init__.py b/tests/unit_tests/command/__init__.py index 97081441c..99c57c405 100644 --- a/tests/unit_tests/command/__init__.py +++ b/tests/unit_tests/command/__init__.py @@ -1 +1 @@ -# Unit tests for command module \ No newline at end of file +# Unit tests for command module diff --git a/tests/unit_tests/command/test_cmdmgr.py b/tests/unit_tests/command/test_cmdmgr.py index 067eb7e43..ade27cf48 100644 --- a/tests/unit_tests/command/test_cmdmgr.py +++ b/tests/unit_tests/command/test_cmdmgr.py @@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs: # Should yield CommandNotFoundError (no such command registered) assert len(results) == 1 - assert results[0].error is not None \ No newline at end of file + assert results[0].error is not None diff --git a/tests/unit_tests/command/test_operator.py b/tests/unit_tests/command/test_operator.py index d099c7af8..a1d345292 100644 --- a/tests/unit_tests/command/test_operator.py +++ b/tests/unit_tests/command/test_operator.py @@ -197,6 +197,7 @@ class TestCommandOperatorBase: op = TestOperator(None) # Should not raise import asyncio + asyncio.get_event_loop().run_until_complete(op.initialize()) def test_execute_is_abstract(self): @@ -299,4 +300,4 @@ class TestMultipleOperators: yield None assert AdminOperator.lowest_privilege == 2 - assert SubOperator.lowest_privilege == 1 \ No newline at end of file + assert SubOperator.lowest_privilege == 1 diff --git a/tests/unit_tests/config/test_config_loader.py b/tests/unit_tests/config/test_config_loader.py index f228bf441..ec9942076 100644 --- a/tests/unit_tests/config/test_config_loader.py +++ b/tests/unit_tests/config/test_config_loader.py @@ -25,7 +25,7 @@ class TestYAMLConfigFile: @pytest.mark.asyncio async def test_valid_yaml_loads(self, tmp_path): """Valid YAML config should load correctly.""" - config_file = tmp_path / "test_config.yaml" + config_file = tmp_path / 'test_config.yaml' # Write valid YAML config_file.write_text(""" @@ -51,7 +51,7 @@ settings: @pytest.mark.asyncio async def test_invalid_yaml_raises_error(self, tmp_path): """Invalid YAML should raise clear error.""" - config_file = tmp_path / "invalid.yaml" + config_file = tmp_path / 'invalid.yaml' # Write invalid YAML (unclosed bracket) config_file.write_text(""" @@ -67,13 +67,13 @@ settings: template_data={'name': 'default'}, ) - with pytest.raises(Exception, match="Syntax error"): + with pytest.raises(Exception, match='Syntax error'): await yaml_file.load(completion=False) @pytest.mark.asyncio async def test_missing_config_creates_from_template(self, tmp_path): """Missing config file should be created from template.""" - config_file = tmp_path / "new_config.yaml" + config_file = tmp_path / 'new_config.yaml' # File doesn't exist yet assert not config_file.exists() @@ -92,7 +92,7 @@ settings: @pytest.mark.asyncio async def test_template_completion(self, tmp_path): """Config should be completed with template defaults.""" - config_file = tmp_path / "partial.yaml" + config_file = tmp_path / 'partial.yaml' # Write partial config missing some template keys config_file.write_text(""" @@ -115,7 +115,7 @@ name: custom_name @pytest.mark.asyncio async def test_yaml_save(self, tmp_path): """YAML config can be saved.""" - config_file = tmp_path / "save_test.yaml" + config_file = tmp_path / 'save_test.yaml' yaml_file = YAMLConfigFile( str(config_file), @@ -131,7 +131,7 @@ name: custom_name def test_yaml_save_sync(self, tmp_path): """YAML config can be saved synchronously.""" - config_file = tmp_path / "sync_save.yaml" + config_file = tmp_path / 'sync_save.yaml' yaml_file = YAMLConfigFile( str(config_file), @@ -151,14 +151,18 @@ class TestJSONConfigFile: @pytest.mark.asyncio async def test_valid_json_loads(self, tmp_path): """Valid JSON config should load correctly.""" - config_file = tmp_path / "test_config.json" + config_file = tmp_path / 'test_config.json' # Write valid JSON - config_file.write_text(json.dumps({ - 'name': 'json_app', - 'version': '1.0', - 'settings': {'debug': True, 'port': 8080}, - })) + config_file.write_text( + json.dumps( + { + 'name': 'json_app', + 'version': '1.0', + 'settings': {'debug': True, 'port': 8080}, + } + ) + ) json_file = JSONConfigFile( str(config_file), @@ -174,7 +178,7 @@ class TestJSONConfigFile: @pytest.mark.asyncio async def test_invalid_json_raises_error(self, tmp_path): """Invalid JSON should raise clear error.""" - config_file = tmp_path / "invalid.json" + config_file = tmp_path / 'invalid.json' # Write invalid JSON (missing closing brace) config_file.write_text('{"name": "test", "unclosed": ') @@ -184,13 +188,13 @@ class TestJSONConfigFile: template_data={'name': 'default'}, ) - with pytest.raises(Exception, match="Syntax error"): + with pytest.raises(Exception, match='Syntax error'): await json_file.load(completion=False) @pytest.mark.asyncio async def test_missing_json_creates_from_template(self, tmp_path): """Missing JSON file should be created from template.""" - config_file = tmp_path / "new_config.json" + config_file = tmp_path / 'new_config.json' json_file = JSONConfigFile( str(config_file), @@ -205,7 +209,7 @@ class TestJSONConfigFile: @pytest.mark.asyncio async def test_json_save(self, tmp_path): """JSON config can be saved.""" - config_file = tmp_path / "save_test.json" + config_file = tmp_path / 'save_test.json' json_file = JSONConfigFile( str(config_file), @@ -226,7 +230,7 @@ class TestConfigManager: @pytest.mark.asyncio async def test_config_manager_load(self, tmp_path): """ConfigManager loads config correctly.""" - config_file = tmp_path / "manager_test.yaml" + config_file = tmp_path / 'manager_test.yaml' config_file.write_text('name: managed_app\nversion: "1.0"\n') yaml_file = YAMLConfigFile( @@ -243,7 +247,7 @@ class TestConfigManager: @pytest.mark.asyncio async def test_config_manager_dump(self, tmp_path): """ConfigManager can dump config.""" - config_file = tmp_path / "dump_test.yaml" + config_file = tmp_path / 'dump_test.yaml' yaml_file = YAMLConfigFile( str(config_file), @@ -260,7 +264,7 @@ class TestConfigManager: def test_config_manager_dump_sync(self, tmp_path): """ConfigManager can dump config synchronously.""" - config_file = tmp_path / "sync_dump.yaml" + config_file = tmp_path / 'sync_dump.yaml' yaml_file = YAMLConfigFile( str(config_file), @@ -280,7 +284,7 @@ class TestConfigExists: def test_yaml_exists_true(self, tmp_path): """exists() returns True for existing file.""" - config_file = tmp_path / "exists.yaml" + config_file = tmp_path / 'exists.yaml' config_file.write_text('name: test') yaml_file = YAMLConfigFile(str(config_file), template_data={}) @@ -288,14 +292,14 @@ class TestConfigExists: def test_yaml_exists_false(self, tmp_path): """exists() returns False for missing file.""" - config_file = tmp_path / "missing.yaml" + config_file = tmp_path / 'missing.yaml' yaml_file = YAMLConfigFile(str(config_file), template_data={}) assert yaml_file.exists() is False def test_json_exists_true(self, tmp_path): """exists() returns True for existing JSON file.""" - config_file = tmp_path / "exists.json" + config_file = tmp_path / 'exists.json' config_file.write_text('{}') json_file = JSONConfigFile(str(config_file), template_data={}) @@ -303,7 +307,7 @@ class TestConfigExists: def test_json_exists_false(self, tmp_path): """exists() returns False for missing JSON file.""" - config_file = tmp_path / "missing.json" + config_file = tmp_path / 'missing.json' json_file = JSONConfigFile(str(config_file), template_data={}) - assert json_file.exists() is False \ No newline at end of file + assert json_file.exists() is False diff --git a/tests/unit_tests/core/__init__.py b/tests/unit_tests/core/__init__.py index c02aca956..1b8a74025 100644 --- a/tests/unit_tests/core/__init__.py +++ b/tests/unit_tests/core/__init__.py @@ -1 +1 @@ -"""Core module unit tests.""" \ No newline at end of file +"""Core module unit tests.""" diff --git a/tests/unit_tests/core/test_app_config_validation.py b/tests/unit_tests/core/test_app_config_validation.py index b90a3bd75..ddf9721da 100644 --- a/tests/unit_tests/core/test_app_config_validation.py +++ b/tests/unit_tests/core/test_app_config_validation.py @@ -4,6 +4,7 @@ Tests cover: - _get_positive_int_config() validation - _get_positive_float_config() validation """ + from __future__ import annotations from unittest.mock import Mock @@ -188,4 +189,4 @@ class TestGetPositiveFloatConfig: result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config') assert result == 1.5 - mock_logger.warning.assert_called_once() \ No newline at end of file + mock_logger.warning.assert_called_once() diff --git a/tests/unit_tests/core/test_bootutils_deps.py b/tests/unit_tests/core/test_bootutils_deps.py index 35e928b94..c57baaf4b 100644 --- a/tests/unit_tests/core/test_bootutils_deps.py +++ b/tests/unit_tests/core/test_bootutils_deps.py @@ -27,6 +27,7 @@ class TestCheckDeps: from langbot.pkg.core.bootutils.deps import check_deps import asyncio + result = asyncio.get_event_loop().run_until_complete(check_deps()) assert result == [] @@ -46,6 +47,7 @@ class TestCheckDeps: from langbot.pkg.core.bootutils.deps import check_deps import asyncio + result = asyncio.get_event_loop().run_until_complete(check_deps()) assert 'requests' in result @@ -61,6 +63,7 @@ class TestCheckDeps: from langbot.pkg.core.bootutils.deps import check_deps, required_deps import asyncio + result = asyncio.get_event_loop().run_until_complete(check_deps()) # Should include all required_deps keys @@ -107,6 +110,7 @@ class TestPrecheckPluginDeps: with patch('os.path.exists', return_value=False): with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install: import asyncio + asyncio.get_event_loop().run_until_complete(precheck_plugin_deps()) mock_install.assert_not_called() @@ -129,6 +133,7 @@ class TestPrecheckPluginDeps: with patch('os.listdir', side_effect=mock_listdir): with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install: import asyncio + asyncio.get_event_loop().run_until_complete(precheck_plugin_deps()) mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[]) diff --git a/tests/unit_tests/core/test_load_config.py b/tests/unit_tests/core/test_load_config.py index 839a330f4..9a83f2fb6 100644 --- a/tests/unit_tests/core/test_load_config.py +++ b/tests/unit_tests/core/test_load_config.py @@ -7,6 +7,7 @@ Tests cover: - Dict type skipping - Missing key creation """ + from __future__ import annotations import os @@ -248,15 +249,8 @@ class TestApplyEnvOverridesToConfig: """Test multiple env vars applied in order.""" load_config = get_load_config_module() - cfg = { - 'system': {'name': 'default', 'enable': True}, - 'concurrency': {'pipeline': 5} - } - env = { - 'SYSTEM__NAME': 'custom', - 'SYSTEM__ENABLE': 'false', - 'CONCURRENCY__PIPELINE': '10' - } + cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}} + env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'} with patch.dict(os.environ, env, clear=True): result = load_config._apply_env_overrides_to_config(cfg) @@ -287,4 +281,4 @@ class TestApplyEnvOverridesToConfig: with patch.dict(os.environ, env, clear=True): result = load_config._apply_env_overrides_to_config(cfg) - assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com' \ No newline at end of file + assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com' diff --git a/tests/unit_tests/core/test_stage.py b/tests/unit_tests/core/test_stage.py index e09cbd310..e9b29e76b 100644 --- a/tests/unit_tests/core/test_stage.py +++ b/tests/unit_tests/core/test_stage.py @@ -175,4 +175,4 @@ class TestPreregisteredStages: pass for key in preregistered_stages: - assert isinstance(key, str) \ No newline at end of file + assert isinstance(key, str) diff --git a/tests/unit_tests/core/test_taskmgr.py b/tests/unit_tests/core/test_taskmgr.py index ca05724da..6c0d9828f 100644 --- a/tests/unit_tests/core/test_taskmgr.py +++ b/tests/unit_tests/core/test_taskmgr.py @@ -7,6 +7,7 @@ Tests cover: Note: Uses import_isolation to break circular import chains. """ + from __future__ import annotations import pytest @@ -19,15 +20,17 @@ from typing import Generator class MockLifecycleControlScopeEnum: """Mock enum value for LifecycleControlScope with .value attribute.""" + def __init__(self, value: str): self.value = value def __repr__(self): - return f"LifecycleControlScope.{self.value.upper()}" + return f'LifecycleControlScope.{self.value.upper()}' class MockLifecycleControlScope: """Mock enum for LifecycleControlScope.""" + APPLICATION = MockLifecycleControlScopeEnum('application') PLATFORM = MockLifecycleControlScopeEnum('platform') PIPELINE = MockLifecycleControlScopeEnum('pipeline') @@ -40,17 +43,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]: # Mock modules that cause circular imports mock_entities = MagicMock() mock_entities.LifecycleControlScope = MockLifecycleControlScope - + mock_app = MagicMock() - + mock_importutil = MagicMock() mock_importutil.import_modules_in_pkg = lambda pkg: None mock_importutil.import_modules_in_pkgs = lambda pkgs: None - + mock_http_controller = MagicMock() - + mock_rag_mgr = MagicMock() - + mocks = { 'langbot.pkg.core.entities': mock_entities, 'langbot.pkg.core.app': mock_app, @@ -58,26 +61,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]: 'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr, 'langbot.pkg.utils.importutil': mock_importutil, } - + # Save original state saved = {} for name in mocks: if name in sys.modules: saved[name] = sys.modules[name] - + # Clear taskmgr to force re-import taskmgr_name = 'langbot.pkg.core.taskmgr' if taskmgr_name in sys.modules: saved[taskmgr_name] = sys.modules[taskmgr_name] - + try: # Apply mocks for name, module in mocks.items(): sys.modules[name] = module - + # Clear taskmgr sys.modules.pop(taskmgr_name, None) - + yield finally: # Restore @@ -86,7 +89,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]: sys.modules[name] = saved[name] else: sys.modules.pop(name, None) - + if taskmgr_name in saved: sys.modules[taskmgr_name] = saved[taskmgr_name] else: @@ -97,6 +100,7 @@ def get_taskmgr_classes(): """Get TaskContext, TaskWrapper, AsyncTaskManager classes.""" with isolated_taskmgr_import(): from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager + return TaskContext, TaskWrapper, AsyncTaskManager @@ -194,9 +198,10 @@ class TestTaskContext: """Test TaskContext.placeholder() returns singleton.""" with isolated_taskmgr_import(): from langbot.pkg.core.taskmgr import TaskContext - + # Reset global placeholder import langbot.pkg.core.taskmgr as taskmgr_module + taskmgr_module.placeholder_context = None ctx1 = TaskContext.placeholder() @@ -269,7 +274,8 @@ class TestTaskWrapper: return 'result' wrapper = TaskWrapper( - mock_app, immediate_coro(), + mock_app, + immediate_coro(), name='test_task', label='Test Task', ) @@ -414,7 +420,7 @@ class TestAsyncTaskManager: async def test_cancel_by_scope(self): """Test cancel_by_scope cancels matching tasks.""" _, _, AsyncTaskManager = get_taskmgr_classes() - + mock_app = create_mock_app() manager = AsyncTaskManager(mock_app) @@ -422,16 +428,10 @@ class TestAsyncTaskManager: await asyncio.sleep(10) # Create task with APPLICATION scope - w1 = manager.create_task( - long_coro(), - scopes=[MockLifecycleControlScope.APPLICATION] - ) + w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION]) # Create task with different scope - w2 = manager.create_task( - long_coro(), - scopes=[MockLifecycleControlScope.PIPELINE] - ) + w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE]) manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION) diff --git a/tests/unit_tests/discover/test_engine.py b/tests/unit_tests/discover/test_engine.py index 63ce82d84..0f4efc9c9 100644 --- a/tests/unit_tests/discover/test_engine.py +++ b/tests/unit_tests/discover/test_engine.py @@ -15,68 +15,68 @@ class TestI18nString: def test_create_with_english_only(self): """Create I18nString with only English.""" - i18n = I18nString(en_US="Hello") + i18n = I18nString(en_US='Hello') - assert i18n.en_US == "Hello" + assert i18n.en_US == 'Hello' assert i18n.zh_Hans is None def test_create_with_multiple_languages(self): """Create I18nString with multiple languages.""" i18n = I18nString( - en_US="Hello", - zh_Hans="你好", - zh_Hant="你好", - ja_JP="こんにちは", + en_US='Hello', + zh_Hans='你好', + zh_Hant='你好', + ja_JP='こんにちは', ) - assert i18n.en_US == "Hello" - assert i18n.zh_Hans == "你好" - assert i18n.zh_Hant == "你好" - assert i18n.ja_JP == "こんにちは" + assert i18n.en_US == 'Hello' + assert i18n.zh_Hans == '你好' + assert i18n.zh_Hant == '你好' + assert i18n.ja_JP == 'こんにちは' def test_to_dict_with_english_only(self): """to_dict returns only non-None fields.""" - i18n = I18nString(en_US="Hello") + i18n = I18nString(en_US='Hello') result = i18n.to_dict() - assert result == {"en_US": "Hello"} + assert result == {'en_US': 'Hello'} def test_to_dict_with_multiple_languages(self): """to_dict returns all non-None fields.""" i18n = I18nString( - en_US="Hello", - zh_Hans="你好", + en_US='Hello', + zh_Hans='你好', ) result = i18n.to_dict() - assert result == {"en_US": "Hello", "zh_Hans": "你好"} + assert result == {'en_US': 'Hello', 'zh_Hans': '你好'} def test_to_dict_excludes_none(self): """to_dict excludes None values.""" i18n = I18nString( - en_US="Hello", + en_US='Hello', zh_Hans=None, - ja_JP="こんにちは", + ja_JP='こんにちは', ) result = i18n.to_dict() - assert "zh_Hans" not in result - assert "en_US" in result - assert "ja_JP" in result + assert 'zh_Hans' not in result + assert 'en_US' in result + assert 'ja_JP' in result def test_to_dict_all_languages(self): """to_dict with all supported languages.""" i18n = I18nString( - en_US="Hello", - zh_Hans="你好", - zh_Hant="你好", - ja_JP="こんにちは", - th_TH="สวัสดี", - vi_VN="Xin chào", - es_ES="Hola", + en_US='Hello', + zh_Hans='你好', + zh_Hant='你好', + ja_JP='こんにちは', + th_TH='สวัสดี', + vi_VN='Xin chào', + es_ES='Hola', ) result = i18n.to_dict() @@ -92,30 +92,30 @@ class TestMetadata: from langbot.pkg.discover.engine import I18nString metadata = Metadata( - name="test-component", - label=I18nString(en_US="Test Component"), + name='test-component', + label=I18nString(en_US='Test Component'), ) - assert metadata.name == "test-component" - assert metadata.label.en_US == "Test Component" + assert metadata.name == 'test-component' + assert metadata.label.en_US == 'Test Component' def test_create_with_all_fields(self): """Create Metadata with all optional fields.""" from langbot.pkg.discover.engine import I18nString metadata = Metadata( - name="test-component", - label=I18nString(en_US="Test"), - description=I18nString(en_US="A test component"), - version="1.0.0", - icon="test-icon", - author="Test Author", - repository="https://github.com/test/repo", + name='test-component', + label=I18nString(en_US='Test'), + description=I18nString(en_US='A test component'), + version='1.0.0', + icon='test-icon', + author='Test Author', + repository='https://github.com/test/repo', ) - assert metadata.version == "1.0.0" - assert metadata.icon == "test-icon" - assert metadata.author == "Test Author" + assert metadata.version == '1.0.0' + assert metadata.icon == 'test-icon' + assert metadata.author == 'Test Author' class TestComponentManifest: diff --git a/tests/unit_tests/persistence/test_database_decorator.py b/tests/unit_tests/persistence/test_database_decorator.py index 222cd3a3c..d72d48e0b 100644 --- a/tests/unit_tests/persistence/test_database_decorator.py +++ b/tests/unit_tests/persistence/test_database_decorator.py @@ -7,6 +7,7 @@ Tests cover: Note: Uses import isolation to break circular import chains. """ + from __future__ import annotations import sys @@ -86,6 +87,7 @@ def get_database_module(): """Get database module with import isolation.""" with isolated_database_import(): from langbot.pkg.persistence import database + return database @@ -198,4 +200,4 @@ class TestManagerClassDecorator: # Create instance to test method (with mock app) mock_app = Mock() instance = ManagerWithMethods(mock_app) - assert instance.custom_method() == 'test_value' \ No newline at end of file + assert instance.custom_method() == 'test_value' diff --git a/tests/unit_tests/persistence/test_mgr_methods.py b/tests/unit_tests/persistence/test_mgr_methods.py index 2145f84eb..0c75cf2d4 100644 --- a/tests/unit_tests/persistence/test_mgr_methods.py +++ b/tests/unit_tests/persistence/test_mgr_methods.py @@ -4,6 +4,7 @@ Tests cover: - execute_async() with mock database - get_db_engine() with mock database manager """ + from __future__ import annotations import pytest @@ -85,7 +86,7 @@ class TestExecuteAsync: mock_db.get_engine = Mock(return_value=mock_engine) mgr.db = mock_db - result = await mgr.execute_async(sqlalchemy.text("SELECT 1")) + result = await mgr.execute_async(sqlalchemy.text('SELECT 1')) # Verify result is the same object returned by execute assert result is mock_result @@ -152,4 +153,4 @@ class TestSerializeModelEdgeCases: result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name']) # Result should be empty dict when all columns masked - assert result == {} \ No newline at end of file + assert result == {} diff --git a/tests/unit_tests/persistence/test_serialize_model.py b/tests/unit_tests/persistence/test_serialize_model.py index 199c3a8f2..bbab59c67 100644 --- a/tests/unit_tests/persistence/test_serialize_model.py +++ b/tests/unit_tests/persistence/test_serialize_model.py @@ -5,6 +5,7 @@ Tests cover: - datetime conversion to isoformat - masked_columns exclusion """ + from __future__ import annotations import datetime diff --git a/tests/unit_tests/pipeline/test_aggregator.py b/tests/unit_tests/pipeline/test_aggregator.py index 97ac35c38..9eab7615f 100644 --- a/tests/unit_tests/pipeline/test_aggregator.py +++ b/tests/unit_tests/pipeline/test_aggregator.py @@ -49,7 +49,7 @@ class TestPendingMessage: """PendingMessage should be created with correct fields.""" aggregator = get_aggregator_module() - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -88,7 +88,7 @@ class TestSessionBuffer: """SessionBuffer should accept initial messages.""" aggregator = get_aggregator_module() - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage: agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage: agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -419,7 +419,7 @@ class TestMessageAggregatorMerge: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -445,8 +445,8 @@ class TestMessageAggregatorMerge: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain1 = text_chain("hello") - chain2 = text_chain("world") + chain1 = text_chain('hello') + chain2 = text_chain('world') event = friend_message_event(chain1) adapter = mock_adapter() @@ -476,8 +476,8 @@ class TestMessageAggregatorMerge: # Should contain both messages with separator merged_str = str(merged.message_chain) - assert "hello" in merged_str - assert "world" in merged_str + assert 'hello' in merged_str + assert 'world' in merged_str def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self): """Merged PendingMessage should keep routed_by_rule when any input was rule-routed.""" @@ -486,8 +486,8 @@ class TestMessageAggregatorMerge: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain1 = text_chain("first") - chain2 = text_chain("second") + chain1 = text_chain('first') + chain2 = text_chain('second') event = friend_message_event(chain1) adapter = mock_adapter() @@ -545,7 +545,7 @@ class TestMessageAggregatorFlush: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() @@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll: app = make_aggregator_app() agg = aggregator.MessageAggregator(app) - chain = text_chain("hello") + chain = text_chain('hello') event = friend_message_event(chain) adapter = mock_adapter() diff --git a/tests/unit_tests/pipeline/test_chat_handler.py b/tests/unit_tests/pipeline/test_chat_handler.py index 097ef2b4a..c8a923d78 100644 --- a/tests/unit_tests/pipeline/test_chat_handler.py +++ b/tests/unit_tests/pipeline/test_chat_handler.py @@ -15,6 +15,7 @@ from tests.factories import FakeApp # ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== + @pytest.fixture(scope='module') def mock_circular_import_chain(): """ @@ -36,9 +37,11 @@ def mock_circular_import_chain(): # Create a default runner that yields a simple response class DefaultRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): yield Message(role='assistant', content='fake response') @@ -70,9 +73,12 @@ def mock_event_ctx(): @pytest.fixture def set_runner(): """Factory fixture to set a custom runner for tests.""" + def _set_runner(runner_class): import sys + sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class] + return _set_runner @@ -87,6 +93,7 @@ def get_chat_handler(): global _chat_handler_module if _chat_handler_module is None: from importlib import import_module + _chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat') return _chat_handler_module @@ -96,12 +103,14 @@ def get_entities(): global _entities_module if _entities_module is None: from importlib import import_module + _entities_module = import_module('langbot.pkg.pipeline.entities') return _entities_module # ============== REAL ChatMessageHandler Tests ============== + @pytest.mark.usefixtures('mock_circular_import_chain') class TestChatMessageHandlerReal: """Tests for real ChatMessageHandler class.""" @@ -188,9 +197,11 @@ class TestChatMessageHandlerReal: class QuickRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): yield Message(role='assistant', content='ok') @@ -222,9 +233,11 @@ class TestChatMessageHandlerReal: class SingleRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): yield Message(role='assistant', content='response') @@ -262,9 +275,11 @@ class TestChatHandlerStreaming: class StreamRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): yield MessageChunk(role='assistant', content='Hello', is_final=False) yield MessageChunk(role='assistant', content=' World', is_final=True) @@ -303,14 +318,19 @@ class TestChatHandlerExceptions: query.pipeline_config = { 'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}, + }, } class FailingRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): raise ValueError('API error') yield @@ -346,14 +366,19 @@ class TestChatHandlerExceptions: query.pipeline_config = { 'output': {'misc': {'exception-handling': 'show-error'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}, + }, } class ErrorRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): raise ValueError('Custom error') yield @@ -386,14 +411,19 @@ class TestChatHandlerExceptions: query.pipeline_config = { 'output': {'misc': {'exception-handling': 'hide'}}, - 'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}}, + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}, + }, } class HideErrorRunner: name = 'local-agent' + def __init__(self, app, config): self.app = app self.config = config + async def run(self, query): raise RuntimeError('hidden') yield @@ -433,4 +463,4 @@ class TestChatHandlerHelper: chat = get_chat_handler() handler = chat.ChatMessageHandler(fake_app) result = handler.cut_str('first line\nsecond line') - assert '...' in result \ No newline at end of file + assert '...' in result diff --git a/tests/unit_tests/pipeline/test_cntfilter.py b/tests/unit_tests/pipeline/test_cntfilter.py index 1d29d1797..f0e46b41a 100644 --- a/tests/unit_tests/pipeline/test_cntfilter.py +++ b/tests/unit_tests/pipeline/test_cntfilter.py @@ -67,7 +67,11 @@ def make_pipeline_config(**overrides): for key, value in overrides.items(): if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict): for sub_key, sub_value in value.items(): - if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict): + if ( + sub_key in base_config[key] + and isinstance(base_config[key][sub_key], dict) + and isinstance(sub_value, dict) + ): base_config[key][sub_key].update(sub_value) else: base_config[key][sub_key] = sub_value @@ -141,7 +145,7 @@ class TestPreContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello world") + query = text_query('hello world') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -163,7 +167,7 @@ class TestPreContentFilter: await stage.initialize(pipeline_config) # Empty message chain - query = text_query("") + query = text_query('') query.message_chain = platform_message.MessageChain([]) query.pipeline_config = pipeline_config @@ -185,7 +189,7 @@ class TestPreContentFilter: await stage.initialize(pipeline_config) - query = text_query(" ") # Only whitespace + query = text_query(' ') # Only whitespace query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -234,7 +238,7 @@ class TestPreContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello world") + query = text_query('hello world') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -266,7 +270,7 @@ class TestContentIgnoreFilter: await stage.initialize(pipeline_config) - query = text_query("/help me") + query = text_query('/help me') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -294,7 +298,7 @@ class TestContentIgnoreFilter: await stage.initialize(pipeline_config) - query = text_query("http://example.com") + query = text_query('http://example.com') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -322,7 +326,7 @@ class TestContentIgnoreFilter: await stage.initialize(pipeline_config) - query = text_query("normal message") + query = text_query('normal message') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -343,7 +347,7 @@ class TestContentIgnoreFilter: await stage.initialize(pipeline_config) - query = text_query("/help me") + query = text_query('/help me') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') @@ -368,12 +372,10 @@ class TestPostContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config # Add a response message - query.resp_messages = [ - provider_message.Message(role='assistant', content='Hello back!') - ] + query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')] result = await stage.process(query, 'PostContentFilterStage') @@ -398,11 +400,9 @@ class TestPostContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config - query.resp_messages = [ - provider_message.Message(role='assistant', content='Response') - ] + query.resp_messages = [provider_message.Message(role='assistant', content='Response')] result = await stage.process(query, 'PostContentFilterStage') @@ -422,7 +422,7 @@ class TestPostContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config # Non-string content - use model_construct to bypass validation # The actual content type could be a list of ContentElement objects @@ -450,11 +450,9 @@ class TestPostContentFilter: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config - query.resp_messages = [ - provider_message.Message(role='assistant', content='') - ] + query.resp_messages = [provider_message.Message(role='assistant', content='')] result = await stage.process(query, 'PostContentFilterStage') @@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config with pytest.raises(ValueError, match='未知的 stage_inst_name'): @@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect: await stage.initialize(pipeline_config) - query = text_query("normal message without prefix") + query = text_query('normal message without prefix') query.pipeline_config = pipeline_config result = await stage.process(query, 'PreContentFilterStage') diff --git a/tests/unit_tests/pipeline/test_command_handler.py b/tests/unit_tests/pipeline/test_command_handler.py index 5006d2487..00bd5b681 100644 --- a/tests/unit_tests/pipeline/test_command_handler.py +++ b/tests/unit_tests/pipeline/test_command_handler.py @@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query # ============== FIXTURE USING IMPORT ISOLATION UTILITY ============== + @pytest.fixture(scope='module') def mock_circular_import_chain(): """ @@ -56,6 +57,7 @@ def mock_event_ctx(): @pytest.fixture def mock_execute_factory(): """Factory fixture to create mock cmd_mgr.execute generators.""" + def _create_execute( text: str | None = 'ok', error: str | None = None, @@ -71,7 +73,9 @@ def mock_execute_factory(): ret.image_base64 = image_base64 ret.file_url = file_url yield ret + return mock_execute + return _create_execute @@ -86,6 +90,7 @@ def get_command_handler(): global _command_handler_module if _command_handler_module is None: from importlib import import_module + _command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command') return _command_handler_module @@ -95,12 +100,14 @@ def get_entities(): global _entities_module if _entities_module is None: from importlib import import_module + _entities_module = import_module('langbot.pkg.pipeline.entities') return _entities_module # ============== REAL CommandHandler Tests ============== + @pytest.mark.usefixtures('mock_circular_import_chain') class TestCommandHandlerReal: """Tests for real CommandHandler class.""" @@ -127,6 +134,7 @@ class TestCommandHandlerReal: fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) executed_commands = [] + async def track_execute(command_text, full_command_text, query, session): executed_commands.append(command_text) ret = Mock() @@ -334,8 +342,7 @@ class TestCommandHandlerReal: command = get_command_handler() fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) fake_app.cmd_mgr.execute = mock_execute_factory( - text='Here is the image:', - image_url='https://example.com/image.png' + text='Here is the image:', image_url='https://example.com/image.png' ) handler = command.CommandHandler(fake_app) @@ -393,4 +400,4 @@ class TestCommandHandlerHelper: command = get_command_handler() handler = command.CommandHandler(fake_app) result = handler.cut_str('first line\nsecond line') - assert '...' in result \ No newline at end of file + assert '...' in result diff --git a/tests/unit_tests/pipeline/test_longtext.py b/tests/unit_tests/pipeline/test_longtext.py index 1595cc18c..3a693276f 100644 --- a/tests/unit_tests/pipeline/test_longtext.py +++ b/tests/unit_tests/pipeline/test_longtext.py @@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config - query.resp_message_chain = [ - platform_message.MessageChain([platform_message.Plain(text="very long response")]) - ] + query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])] result = await stage.process(query, 'LongTextProcessStage') @@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config - query.resp_message_chain = [ - platform_message.MessageChain([platform_message.Plain(text="short response")]) - ] + query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])] result = await stage.process(query, 'LongTextProcessStage') @@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config # Non-Plain component (Image) query.resp_message_chain = [ - platform_message.MessageChain([ - platform_message.Plain(text="short"), - platform_message.Image(url="https://example.com/img.png") - ]) + platform_message.MessageChain( + [platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')] + ) ] result = await stage.process(query, 'LongTextProcessStage') @@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] @@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess: stage = longtext.LongTextProcessStage(app) stage.strategy_impl = AsyncMock() - query = text_query("hello") + query = text_query('hello') query.pipeline_config = make_longtext_config(strategy='forward', threshold=1) query.resp_message_chain = [] @@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess: assert result.new_query is query stage.strategy_impl.process.assert_not_called() + class TestForwardStrategy: """Tests for ForwardComponentStrategy.""" @@ -260,7 +256,7 @@ class TestForwardStrategy: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config # Create a mock adapter with bot_account_id mock_adapter = Mock() @@ -268,10 +264,8 @@ class TestForwardStrategy: query.adapter = mock_adapter # Long text exceeding threshold - long_text = "This is a very long response that exceeds the threshold" - query.resp_message_chain = [ - platform_message.MessageChain([platform_message.Plain(text=long_text)]) - ] + long_text = 'This is a very long response that exceeds the threshold' + query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])] result = await stage.process(query, 'LongTextProcessStage') @@ -297,13 +291,13 @@ class TestForwardStrategy: await strat.initialize() - query = text_query("hello") + query = text_query('hello') query.pipeline_config = make_longtext_config() mock_adapter = Mock() mock_adapter.bot_account_id = '12345' query.adapter = mock_adapter - components = await strat.process("test message", query) + components = await strat.process('test message', query) assert len(components) == 1 assert isinstance(components[0], platform_message.Forward) @@ -326,14 +320,12 @@ class TestLongTextThreshold: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config # Text below threshold - short_text = "x" * (threshold - 1) - query.resp_message_chain = [ - platform_message.MessageChain([platform_message.Plain(text=short_text)]) - ] + short_text = 'x' * (threshold - 1) + query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])] result = await stage.process(query, 'LongTextProcessStage') diff --git a/tests/unit_tests/pipeline/test_msgtrun.py b/tests/unit_tests/pipeline/test_msgtrun.py index 9cfdababf..4470c6945 100644 --- a/tests/unit_tests/pipeline/test_msgtrun.py +++ b/tests/unit_tests/pipeline/test_msgtrun.py @@ -115,7 +115,7 @@ class TestRoundTruncatorProcess: await stage.initialize(pipeline_config) # Create query with 3 messages (within limit) - query = text_query("current message") + query = text_query('current message') query.pipeline_config = pipeline_config query.messages = [ provider_message.Message(role='user', content='message 1'), @@ -154,7 +154,7 @@ class TestRoundTruncatorProcess: # Create query with many messages exceeding limit # 7 messages = 3 full rounds + 1 current user - query = text_query("current message") + query = text_query('current message') query.pipeline_config = pipeline_config query.messages = [ provider_message.Message(role='user', content='message 1'), @@ -194,7 +194,7 @@ class TestRoundTruncatorProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.messages = [] @@ -216,7 +216,7 @@ class TestRoundTruncatorProcess: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.messages = [ provider_message.Message(role='user', content='hello'), @@ -240,7 +240,7 @@ class TestRoundTruncatorProcess: await stage.initialize(pipeline_config) - query = text_query("current") + query = text_query('current') query.pipeline_config = pipeline_config query.messages = [ provider_message.Message(role='user', content='user1'), @@ -274,7 +274,7 @@ class TestRoundTruncatorProcess: await stage.initialize(pipeline_config) - query = text_query("current") + query = text_query('current') query.pipeline_config = pipeline_config query.messages = [ provider_message.Message(role='user', content='old1'), @@ -305,7 +305,7 @@ class TestRoundTruncatorDirect: trun = trun_cls(app) break - query = text_query("hello") + query = text_query('hello') query.pipeline_config = make_truncate_config(max_round=3) query.messages = [ provider_message.Message(role='user', content='m1'), diff --git a/tests/unit_tests/pipeline/test_preproc.py b/tests/unit_tests/pipeline/test_preproc.py index 1413f5f74..9cbf65265 100644 --- a/tests/unit_tests/pipeline/test_preproc.py +++ b/tests/unit_tests/pipeline/test_preproc.py @@ -78,7 +78,7 @@ class TestPreProcessorNormalText: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = text_query("hello world") + query = text_query('hello world') result = await stage.process(query, 'PreProcessor') @@ -113,7 +113,7 @@ class TestPreProcessorNormalText: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = text_query("test message") + query = text_query('test message') result = await stage.process(query, 'PreProcessor') @@ -194,13 +194,16 @@ class TestPreProcessorImageSegment: stage = preproc.PreProcessor(app) # Image query with base64 - query = image_query(text="look at this", url=None) + query = image_query(text='look at this', url=None) # Set base64 on the image component import langbot_plugin.api.entities.builtin.platform.message as platform_message - chain = platform_message.MessageChain([ - platform_message.Plain(text="look at this"), - platform_message.Image(base64="data:image/png;base64,abc123"), - ]) + + chain = platform_message.MessageChain( + [ + platform_message.Plain(text='look at this'), + platform_message.Image(base64='data:image/png;base64,abc123'), + ] + ) query.message_chain = chain result = await stage.process(query, 'PreProcessor') @@ -238,7 +241,7 @@ class TestPreProcessorImageSegment: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = image_query(text="describe this") + query = image_query(text='describe this') result = await stage.process(query, 'PreProcessor') @@ -276,7 +279,7 @@ class TestPreProcessorModelSelection: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = text_query("hello") + query = text_query('hello') # Set pipeline config with primary model query.pipeline_config = { @@ -335,7 +338,7 @@ class TestPreProcessorModelSelection: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = { 'ai': { @@ -384,7 +387,7 @@ class TestPreProcessorVariables: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = text_query("hello", sender_id=67890) + query = text_query('hello', sender_id=67890) result = await stage.process(query, 'PreProcessor') @@ -421,7 +424,7 @@ class TestPreProcessorVariables: app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) stage = preproc.PreProcessor(app) - query = group_text_query("hello", group_id=99999) + query = group_text_query('hello', group_id=99999) result = await stage.process(query, 'PreProcessor') diff --git a/tests/unit_tests/pipeline/test_ratelimit.py b/tests/unit_tests/pipeline/test_ratelimit.py index a06c3b674..be767ad56 100644 --- a/tests/unit_tests/pipeline/test_ratelimit.py +++ b/tests/unit_tests/pipeline/test_ratelimit.py @@ -46,7 +46,7 @@ class TestFixedWindowAlgo: 'safety': { 'rate-limit': { 'window-length': 60, # 60 seconds window - 'limitation': 10, # 10 requests per window + 'limitation': 10, # 10 requests per window 'strategy': 'drop', } } @@ -75,11 +75,9 @@ class TestFixedWindowAlgo: # Make requests within limit for i in range(10): result = await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - '12345' + sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345' ) - assert result is True, f"Request {i+1} should be allowed" + assert result is True, f'Request {i + 1} should be allowed' @pytest.mark.asyncio async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit): @@ -91,20 +89,12 @@ class TestFixedWindowAlgo: # Exhaust the limit for i in range(10): - await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - '12345' - ) + await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') # Next request should be denied - result = await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - '12345' - ) + result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') - assert result is False, "Request exceeding limit should be denied" + assert result is False, 'Request exceeding limit should be denied' @pytest.mark.asyncio async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit): @@ -116,20 +106,14 @@ class TestFixedWindowAlgo: # Exhaust limit for session 1 for i in range(10): - await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - 'session1' - ) + await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1') # Session 2 should still have its own limit result = await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - 'session2' + sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2' ) - assert result is True, "Different session should have independent limit" + assert result is True, 'Different session should have independent limit' @pytest.mark.asyncio async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query): @@ -150,19 +134,11 @@ class TestFixedWindowAlgo: await algo.initialize() # First request allowed - result1 = await algo.require_access( - sample_query, - provider_session.LauncherTypes.PERSON, - '12345' - ) + result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345') assert result1 is True # Second request denied - result2 = await algo.require_access( - sample_query, - provider_session.LauncherTypes.PERSON, - '12345' - ) + result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345') assert result2 is False @pytest.mark.asyncio @@ -174,11 +150,7 @@ class TestFixedWindowAlgo: await algo.initialize() # First request creates container - await algo.require_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - '12345' - ) + await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') # Key format: 'LauncherTypes.PERSON_12345' (enum string representation) expected_key = 'LauncherTypes.PERSON_12345' @@ -230,7 +202,7 @@ class TestFixedWindowAlgo: # New request should be allowed (new window) result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test') - assert result is True, "New window should allow new requests" + assert result is True, 'New window should allow new requests' @pytest.mark.asyncio async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query): @@ -256,29 +228,21 @@ class TestFixedWindowAlgo: # First request allowed start_time = time.time() - result1 = await algo.require_access( - sample_query, - provider_session.LauncherTypes.PERSON, - 'wait_test' - ) + result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') assert result1 is True # Exhaust limit await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') # Third request should wait and then succeed - result3 = await algo.require_access( - sample_query, - provider_session.LauncherTypes.PERSON, - 'wait_test' - ) + result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test') elapsed = time.time() - start_time - assert result3 is True, "After wait, request should succeed" + assert result3 is True, 'After wait, request should succeed' # Should have waited approximately until next window # With 1-second window, elapsed should be > 0.5 second (allowing for timing variance) # Note: This is a timing-sensitive test, so we use a generous tolerance - assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s" + assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s' @pytest.mark.asyncio async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit): @@ -289,11 +253,7 @@ class TestFixedWindowAlgo: await algo.initialize() # release_access is empty in current implementation - await algo.release_access( - sample_query_with_rate_limit, - provider_session.LauncherTypes.PERSON, - '12345' - ) + await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345') # Should not raise or change state assert 'person_12345' not in algo.containers diff --git a/tests/unit_tests/pipeline/test_wrapper.py b/tests/unit_tests/pipeline/test_wrapper.py index e5d47c76f..8dea6c8bb 100644 --- a/tests/unit_tests/pipeline/test_wrapper.py +++ b/tests/unit_tests/pipeline/test_wrapper.py @@ -55,7 +55,7 @@ def make_session(): launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=12345, sender_id=12345, - use_prompt_name="default", + use_prompt_name='default', using_conversation=None, conversations=[], ) @@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config - query.resp_messages = [ - platform_message.MessageChain([platform_message.Plain(text="response")]) - ] + query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])] query.resp_message_chain = [] results = [] @@ -125,7 +123,7 @@ class TestResponseWrapperCommand: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] @@ -133,7 +131,7 @@ class TestResponseWrapperCommand: command_resp = Mock() command_resp.role = 'command' command_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')]) ) query.resp_messages = [command_resp] @@ -163,7 +161,7 @@ class TestResponseWrapperPlugin: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] @@ -171,7 +169,7 @@ class TestResponseWrapperPlugin: plugin_resp = Mock() plugin_resp.role = 'plugin' plugin_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')]) ) query.resp_messages = [plugin_resp] @@ -211,17 +209,17 @@ class TestResponseWrapperAssistant: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] # Create assistant response with content assistant_resp = Mock() assistant_resp.role = 'assistant' - assistant_resp.content = "Hello back!" + assistant_resp.content = 'Hello back!' assistant_resp.tool_calls = None assistant_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')]) ) query.resp_messages = [assistant_resp] @@ -247,7 +245,7 @@ class TestResponseWrapperAssistant: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] @@ -292,7 +290,7 @@ class TestResponseWrapperAssistant: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] @@ -303,10 +301,10 @@ class TestResponseWrapperAssistant: assistant_resp = Mock() assistant_resp.role = 'assistant' - assistant_resp.content = "Processing..." + assistant_resp.content = 'Processing...' assistant_resp.tool_calls = [mock_tool_call] assistant_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')]) ) query.resp_messages = [assistant_resp] @@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] # Create assistant response with content assistant_resp = Mock() assistant_resp.role = 'assistant' - assistant_resp.content = "Hello!" + assistant_resp.content = 'Hello!' assistant_resp.tool_calls = None assistant_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')]) ) query.resp_messages = [assistant_resp] @@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply: app.sess_mgr.get_session = AsyncMock(return_value=session) # Mock plugin connector with custom reply - custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")]) + custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')]) mock_event_ctx = Mock() mock_event_ctx.is_prevented_default = Mock(return_value=False) mock_event_ctx.event = Mock() @@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] # Create assistant response assistant_resp = Mock() assistant_resp.role = 'assistant' - assistant_resp.content = "Default reply" + assistant_resp.content = 'Default reply' assistant_resp.tool_calls = None assistant_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')]) ) query.resp_messages = [assistant_resp] @@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply: assert len(results[0].new_query.resp_message_chain) == 1 # Should be the custom chain chain = results[0].new_query.resp_message_chain[0] - assert "Custom reply" in str(chain) + assert 'Custom reply' in str(chain) class TestResponseWrapperVariables: @@ -452,7 +450,7 @@ class TestResponseWrapperVariables: await stage.initialize(pipeline_config) - query = text_query("hello") + query = text_query('hello') query.pipeline_config = pipeline_config query.resp_message_chain = [] query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2'] @@ -460,10 +458,10 @@ class TestResponseWrapperVariables: # Create assistant response assistant_resp = Mock() assistant_resp.role = 'assistant' - assistant_resp.content = "Hello" + assistant_resp.content = 'Hello' assistant_resp.tool_calls = None assistant_resp.get_content_platform_message_chain = Mock( - return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")]) + return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')]) ) query.resp_messages = [assistant_resp] diff --git a/tests/unit_tests/plugin/test_connector_methods.py b/tests/unit_tests/plugin/test_connector_methods.py index 10ce24191..5f09ce5a4 100644 --- a/tests/unit_tests/plugin/test_connector_methods.py +++ b/tests/unit_tests/plugin/test_connector_methods.py @@ -6,6 +6,7 @@ Tests cover: - RAG methods (ingest, retrieve, schema) - Disabled plugin early returns """ + from __future__ import annotations import pytest @@ -86,16 +87,12 @@ class TestListPlugins: return_value=[ { 'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}}, - 'components': [ - {'manifest': {'manifest': {'kind': 'Command'}}} - ], + 'components': [{'manifest': {'manifest': {'kind': 'Command'}}}], 'debug': False, }, { 'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}}, - 'components': [ - {'manifest': {'manifest': {'kind': 'Tool'}}} - ], + 'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}], 'debug': False, }, ] @@ -127,9 +124,7 @@ class TestListPlugins: }, ] ) - connector.ap.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(__iter__=lambda self: iter([])) - ) + connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([]))) result = await connector.list_plugins() @@ -230,7 +225,8 @@ class TestCallParser: ) connector.handler.parse_document.assert_called_once_with( - 'author', 'parser', + 'author', + 'parser', {'mime_type': 'text/plain', 'filename': 'test.txt'}, b'file content', ) @@ -251,9 +247,7 @@ class TestRAGMethods: result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'}) - connector.handler.rag_ingest_document.assert_called_once_with( - 'author', 'engine', {'file': 'test.pdf'} - ) + connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'}) assert result['status'] == 'success' @pytest.mark.asyncio @@ -264,14 +258,16 @@ class TestRAGMethods: connector.handler = AsyncMock() connector.handler.retrieve_knowledge = AsyncMock( - return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]} + return_value={ + 'results': [ + {'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1} + ] + } ) result = await connector.call_rag_retrieve('author/engine', {'query': 'test'}) - connector.handler.retrieve_knowledge.assert_called_once_with( - 'author', 'engine', '', {'query': 'test'} - ) + connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'}) assert result == { 'results': [ { @@ -290,9 +286,7 @@ class TestRAGMethods: connector = create_mock_connector() connector.handler = AsyncMock() - connector.handler.get_rag_creation_schema = AsyncMock( - return_value={'properties': {'name': {'type': 'string'}}} - ) + connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}}) result = await connector.get_rag_creation_schema('author/engine') @@ -326,9 +320,7 @@ class TestRAGMethods: await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'}) - connector.handler.rag_on_kb_create.assert_called_once_with( - 'author', 'engine', 'kb-uuid', {'model': 'test'} - ) + connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'}) @pytest.mark.asyncio async def test_rag_on_kb_delete(self): @@ -354,9 +346,7 @@ class TestRAGMethods: result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid') - connector.handler.rag_delete_document.assert_called_once_with( - 'author', 'engine', 'doc-uuid', 'kb-uuid' - ) + connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid') assert result is True @@ -446,9 +436,7 @@ class TestGetPluginInfo: connector = create_mock_connector() connector.handler = AsyncMock() - connector.handler.get_plugin_info = AsyncMock( - return_value={'manifest': {'metadata': {'name': 'plugin'}}} - ) + connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}}) result = await connector.get_plugin_info('author', 'plugin') @@ -470,9 +458,7 @@ class TestSetPluginConfig: await connector.set_plugin_config('author', 'plugin', {'setting': 'value'}) - connector.handler.set_plugin_config.assert_called_once_with( - 'author', 'plugin', {'setting': 'value'} - ) + connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'}) class TestPingPluginRuntime: diff --git a/tests/unit_tests/plugin/test_connector_static.py b/tests/unit_tests/plugin/test_connector_static.py index 77747b7b8..8c88b9707 100644 --- a/tests/unit_tests/plugin/test_connector_static.py +++ b/tests/unit_tests/plugin/test_connector_static.py @@ -3,6 +3,7 @@ Tests cover: - _parse_plugin_id() parsing and validation """ + from __future__ import annotations import pytest diff --git a/tests/unit_tests/plugin/test_extract_deps.py b/tests/unit_tests/plugin/test_extract_deps.py index e9c30ec99..0980f4cc3 100644 --- a/tests/unit_tests/plugin/test_extract_deps.py +++ b/tests/unit_tests/plugin/test_extract_deps.py @@ -6,6 +6,7 @@ Tests cover: - Handling missing requirements.txt - Handling empty/malformed requirements.txt """ + from __future__ import annotations import zipfile @@ -82,13 +83,13 @@ class TestExtractDepsMetadata: """Test that comments and empty lines are filtered.""" connector_instance = create_mock_connector() - requirements = '''# This is a comment + requirements = """# This is a comment requests>=2.0 # Another comment flask==1.0 -numpy''' +numpy""" zip_bytes = create_zip_with_requirements(requirements) task_context = Mock() @@ -147,9 +148,9 @@ numpy''' """Test handling requirements.txt with only comments.""" connector_instance = create_mock_connector() - requirements = '''# Comment 1 + requirements = """# Comment 1 # Comment 2 -# Comment 3''' +# Comment 3""" zip_bytes = create_zip_with_requirements(requirements) task_context = Mock() diff --git a/tests/unit_tests/plugin/test_handler.py b/tests/unit_tests/plugin/test_handler.py index 44522ef46..989a333a4 100644 --- a/tests/unit_tests/plugin/test_handler.py +++ b/tests/unit_tests/plugin/test_handler.py @@ -40,11 +40,13 @@ class TestHandlerQueryVariables: """Test set_query_var returns error when query not found.""" runtime_handler = make_handler(mock_app) - response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({ - 'query_id': 'nonexistent-query', - 'key': 'test_var', - 'value': 'test_value', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]( + { + 'query_id': 'nonexistent-query', + 'key': 'test_var', + 'value': 'test_value', + } + ) assert response.code != 0 assert 'nonexistent-query' in response.message @@ -58,11 +60,13 @@ class TestHandlerQueryVariables: mock_app.query_pool.cached_queries['test-query'] = mock_query - response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({ - 'query_id': 'test-query', - 'key': 'test_var', - 'value': 'test_value', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]( + { + 'query_id': 'test-query', + 'key': 'test_var', + 'value': 'test_value', + } + ) assert response.code == 0 assert mock_query.variables['test_var'] == 'test_value' @@ -76,10 +80,12 @@ class TestHandlerQueryVariables: mock_app.query_pool.cached_queries['test-query'] = mock_query - response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({ - 'query_id': 'test-query', - 'key': 'existing_var', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]( + { + 'query_id': 'test-query', + 'key': 'existing_var', + } + ) assert response.code == 0 assert response.data == {'value': 'existing_value'} @@ -93,9 +99,11 @@ class TestHandlerQueryVariables: mock_app.query_pool.cached_queries['test-query'] = mock_query - response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({ - 'query_id': 'test-query', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]( + { + 'query_id': 'test-query', + } + ) assert response.code == 0 assert response.data == {'vars': mock_query.variables} @@ -108,7 +116,7 @@ class TestHandlerRagErrorResponse: """Test basic error response creation.""" from langbot.pkg.plugin.handler import _make_rag_error_response - error = Exception("test error") + error = Exception('test error') response = _make_rag_error_response(error, 'TestError') # ActionResponse is a pydantic model, check message field @@ -120,13 +128,8 @@ class TestHandlerRagErrorResponse: """Test error response with extra context.""" from langbot.pkg.plugin.handler import _make_rag_error_response - error = ValueError("invalid input") - response = _make_rag_error_response( - error, - 'ValidationError', - field='name', - value='test' - ) + error = ValueError('invalid input') + response = _make_rag_error_response(error, 'ValidationError', field='name', value='test') assert 'ValidationError' in response.message assert 'field=name' in response.message @@ -137,7 +140,7 @@ class TestHandlerRagErrorResponse: """Test error response includes exception type.""" from langbot.pkg.plugin.handler import _make_rag_error_response - error = RuntimeError("connection failed") + error = RuntimeError('connection failed') response = _make_rag_error_response(error, 'ConnectionError') assert 'RuntimeError' in response.message @@ -148,7 +151,7 @@ class TestHandlerRagErrorResponse: """Test error response with no extra context.""" from langbot.pkg.plugin.handler import _make_rag_error_response - error = KeyError("missing_key") + error = KeyError('missing_key') response = _make_rag_error_response(error, 'LookupError') # No context parts means no brackets diff --git a/tests/unit_tests/plugin/test_handler_actions.py b/tests/unit_tests/plugin/test_handler_actions.py index 81bc75705..490ce0b22 100644 --- a/tests/unit_tests/plugin/test_handler_actions.py +++ b/tests/unit_tests/plugin/test_handler_actions.py @@ -47,12 +47,14 @@ class TestInitializePluginSettings: Mock(), ] - response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({ - 'plugin_author': 'test-author', - 'plugin_name': 'test-plugin', - 'install_source': 'local', - 'install_info': {'path': '/test'}, - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]( + { + 'plugin_author': 'test-author', + 'plugin_name': 'test-plugin', + 'install_source': 'local', + 'install_info': {'path': '/test'}, + } + ) assert response.code == 0 assert app.persistence_mgr.execute_async.await_count == 2 @@ -82,12 +84,14 @@ class TestInitializePluginSettings: Mock(), ] - response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({ - 'plugin_author': 'test-author', - 'plugin_name': 'test-plugin', - 'install_source': 'github', - 'install_info': {'repo': 'author/name'}, - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]( + { + 'plugin_author': 'test-author', + 'plugin_name': 'test-plugin', + 'install_source': 'github', + 'install_info': {'repo': 'author/name'}, + } + ) assert response.code == 0 assert app.persistence_mgr.execute_async.await_count == 3 @@ -161,9 +165,7 @@ class TestSetBinaryStorage: runtime_handler = make_handler(app) app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old')) - response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value]( - self.payload(b'new') - ) + response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new')) assert response.code == 0 assert app.persistence_mgr.execute_async.await_count == 2 @@ -203,9 +205,7 @@ class TestSetBinaryStorage: runtime_handler = make_handler(app) app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0 - response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value]( - self.payload(b'x') - ) + response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x')) assert response.code != 0 assert '1 > 0 bytes' in response.message @@ -228,10 +228,12 @@ class TestGetPluginSettings: runtime_handler = make_handler(app) app.persistence_mgr.execute_async.return_value = make_result() - response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({ - 'plugin_author': 'test-author', - 'plugin_name': 'test-plugin', - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]( + { + 'plugin_author': 'test-author', + 'plugin_name': 'test-plugin', + } + ) assert response.code == 0 assert response.data == { @@ -255,10 +257,12 @@ class TestGetPluginSettings: ) app.persistence_mgr.execute_async.return_value = make_result(setting) - response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({ - 'plugin_author': 'test-author', - 'plugin_name': 'test-plugin', - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]( + { + 'plugin_author': 'test-author', + 'plugin_name': 'test-plugin', + } + ) assert response.code == 0 assert response.data == { @@ -286,11 +290,13 @@ class TestGetBinaryStorage: runtime_handler = make_handler(app) app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content')) - response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({ - 'key': 'test-key', - 'owner_type': 'plugin', - 'owner': 'test-owner', - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]( + { + 'key': 'test-key', + 'owner_type': 'plugin', + 'owner': 'test-owner', + } + ) assert response.code == 0 assert response.data == { @@ -303,11 +309,13 @@ class TestGetBinaryStorage: runtime_handler = make_handler(app) app.persistence_mgr.execute_async.return_value = make_result() - response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({ - 'key': 'test-key', - 'owner_type': 'plugin', - 'owner': 'test-owner', - }) + response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]( + { + 'key': 'test-key', + 'owner_type': 'plugin', + 'owner': 'test-owner', + } + ) assert response.code != 0 assert 'Storage with key test-key not found' in response.message @@ -329,9 +337,11 @@ class TestHandlerQueryLookup: """Query-bound actions return error when query_id is not cached.""" runtime_handler = make_handler(app) - response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({ - 'query_id': 'nonexistent-query', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]( + { + 'query_id': 'nonexistent-query', + } + ) assert response.code != 0 assert 'nonexistent-query' in response.message @@ -343,9 +353,11 @@ class TestHandlerQueryLookup: query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid') app.query_pool.cached_queries['existing-query'] = query - response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({ - 'query_id': 'existing-query', - }) + response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]( + { + 'query_id': 'existing-query', + } + ) assert response.code == 0 assert response.data == {'bot_uuid': 'test-bot-uuid'} diff --git a/tests/unit_tests/plugin/test_handler_helpers.py b/tests/unit_tests/plugin/test_handler_helpers.py index 81bbe010e..dc86b7a5f 100644 --- a/tests/unit_tests/plugin/test_handler_helpers.py +++ b/tests/unit_tests/plugin/test_handler_helpers.py @@ -4,6 +4,7 @@ Tests cover: - _make_rag_error_response() helper function - RuntimeConnectionHandler cleanup_plugin_data method """ + from __future__ import annotations import pytest @@ -23,7 +24,7 @@ class TestMakeRagErrorResponse: """Test basic error response creation.""" handler = get_handler_module() - error = ValueError("test error message") + error = ValueError('test error message') result = handler._make_rag_error_response(error, 'TestError') # ActionResponse.error() returns code=1 (error status) @@ -36,7 +37,7 @@ class TestMakeRagErrorResponse: """Test that error type is included in message.""" handler = get_handler_module() - error = RuntimeError("something went wrong") + error = RuntimeError('something went wrong') result = handler._make_rag_error_response(error, 'VectorStoreError') assert '[VectorStoreError/RuntimeError]' in result.message @@ -45,7 +46,7 @@ class TestMakeRagErrorResponse: """Test that extra context fields are included.""" handler = get_handler_module() - error = Exception("embedding failed") + error = Exception('embedding failed') result = handler._make_rag_error_response( error, 'EmbeddingError', @@ -71,7 +72,7 @@ class TestMakeRagErrorResponse: """Test multiple context fields are comma separated.""" handler = get_handler_module() - error = IOError("file not found") + error = IOError('file not found') result = handler._make_rag_error_response( error, 'FileServiceError', @@ -119,9 +120,7 @@ class TestCleanupPluginData: handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler) handler_instance.ap = mock_app - await handler_module.RuntimeConnectionHandler.cleanup_plugin_data( - handler_instance, 'author', 'plugin-name' - ) + await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name') # Should have at least 2 calls: one for settings, one for binary storage - assert mock_app.persistence_mgr.execute_async.call_count >= 2 \ No newline at end of file + assert mock_app.persistence_mgr.execute_async.call_count >= 2 diff --git a/tests/unit_tests/provider/conftest.py b/tests/unit_tests/provider/conftest.py index 71dd5cd89..13b44fd14 100644 --- a/tests/unit_tests/provider/conftest.py +++ b/tests/unit_tests/provider/conftest.py @@ -88,7 +88,10 @@ class AnotherFakeRequester(requester.ProviderAPIRequester): async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False): import langbot_plugin.api.entities.builtin.provider.message as provider_message - return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]) + + return provider_message.Message( + role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')] + ) async def invoke_rerank(self, model, query: str, documents: list, extra_args={}): """Return fake rerank results.""" @@ -135,8 +138,10 @@ def mock_app_for_modelmgr(): # Fake persistence manager - returns empty results by default app.persistence_mgr = SimpleNamespace() + async def default_execute(query): return _make_mock_result([]) + app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute) # Fake discover engine @@ -165,9 +170,7 @@ def fake_requester_registry(mock_app_for_modelmgr): fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester) another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester) - app.discover.get_components_by_kind = Mock( - return_value=[fake_component, another_component] - ) + app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component]) model_mgr = ModelManager(app) return model_mgr diff --git a/tests/unit_tests/provider/runners/test_difysvapi_runner.py b/tests/unit_tests/provider/runners/test_difysvapi_runner.py index b00c9a10a..366ef6d87 100644 --- a/tests/unit_tests/provider/runners/test_difysvapi_runner.py +++ b/tests/unit_tests/provider/runners/test_difysvapi_runner.py @@ -26,7 +26,7 @@ class TestDifyExtractTextOutput: 'base-url': 'https://api.dify.ai', } }, - 'output': {'misc': {}} + 'output': {'misc': {}}, } runner = DifyServiceAPIRunner(mock_app, pipeline_config) @@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation: 'base-url': 'https://api.dify.ai', } }, - 'output': {'misc': {}} + 'output': {'misc': {}}, } with pytest.raises(DifyAPIError, match='不支持'): @@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation: 'base-url': 'https://api.dify.ai', } }, - 'output': {'misc': {}} + 'output': {'misc': {}}, } runner = DifyServiceAPIRunner(mock_app, pipeline_config) @@ -160,10 +160,10 @@ class TestDifyRunnerInit: 'base-url': 'https://api.dify.ai', } }, - 'output': {'misc': {}} + 'output': {'misc': {}}, } runner = DifyServiceAPIRunner(mock_app, pipeline_config) assert runner.pipeline_config == pipeline_config - assert runner.ap == mock_app \ No newline at end of file + assert runner.ap == mock_app diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index abe0cf498..91d00b19f 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -1062,9 +1062,7 @@ class TestScanModels: with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info: mock_get_model_info.side_effect = ( - lambda model: {'max_input_tokens': 131072} - if model == 'moonshot/moonshot-v1-128k' - else {} + lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {} ) assert requester._safe_context_length('moonshot-v1-128k') == 131072 diff --git a/tests/unit_tests/provider/test_model_manager.py b/tests/unit_tests/provider/test_model_manager.py index b6e82d3fb..015fd5450 100644 --- a/tests/unit_tests/provider/test_model_manager.py +++ b/tests/unit_tests/provider/test_model_manager.py @@ -635,7 +635,9 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry): @pytest.mark.asyncio -async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider): +async def test_model_manager_load_llm_model_with_provider( + fake_requester_registry, fake_persistence_data, runtime_provider +): """Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel.""" model_mgr = fake_requester_registry @@ -648,7 +650,9 @@ async def test_model_manager_load_llm_model_with_provider(fake_requester_registr @pytest.mark.asyncio -async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider): +async def test_model_manager_load_llm_model_with_provider_from_row( + fake_requester_registry, fake_persistence_data, runtime_provider +): """Test ModelManager.load_llm_model_with_provider handles Row objects.""" model_mgr = fake_requester_registry @@ -661,7 +665,9 @@ async def test_model_manager_load_llm_model_with_provider_from_row(fake_requeste @pytest.mark.asyncio -async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider): +async def test_model_manager_load_embedding_model_with_provider( + fake_requester_registry, fake_persistence_data, runtime_provider +): """Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel.""" model_mgr = fake_requester_registry diff --git a/tests/unit_tests/provider/test_requester_base.py b/tests/unit_tests/provider/test_requester_base.py index c34556cdb..71c0da653 100644 --- a/tests/unit_tests/provider/test_requester_base.py +++ b/tests/unit_tests/provider/test_requester_base.py @@ -43,6 +43,7 @@ class TestableRequester(requester.ProviderAPIRequester): remove_think=False, ): import langbot_plugin.api.entities.builtin.provider.message as provider_message + return provider_message.Message( role='assistant', content=[provider_message.ContentElement(type='text', text='Testable response')], @@ -289,7 +290,9 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l current_stage_name=None, ) - messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])] + messages = [ + provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')]) + ] result = await provider.invoke_llm(query, runtime_llm_model, messages) @@ -330,7 +333,9 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider current_stage_name=None, ) - messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])] + messages = [ + provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')]) + ] chunks = [] async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages): @@ -576,7 +581,9 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg current_stage_name=None, ) - messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])] + messages = [ + provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')]) + ] with pytest.raises(RequesterError): await provider.invoke_llm(query, model, messages) diff --git a/tests/unit_tests/provider/test_session_manager.py b/tests/unit_tests/provider/test_session_manager.py index 4698bc494..eca8cac8a 100644 --- a/tests/unit_tests/provider/test_session_manager.py +++ b/tests/unit_tests/provider/test_session_manager.py @@ -5,6 +5,7 @@ Tests cover: - Conversation creation with prompts - Session concurrency semaphore """ + from __future__ import annotations import pytest @@ -60,11 +61,7 @@ class TestSessionManagerGetSession: """Create mock app with instance config.""" mock_app = Mock() mock_app.instance_config = Mock() - mock_app.instance_config.data = { - 'concurrency': { - 'session': 5 - } - } + mock_app.instance_config.data = {'concurrency': {'session': 5}} return mock_app @pytest.fixture @@ -173,11 +170,7 @@ class TestSessionManagerGetConversation: """Create mock app with instance config.""" mock_app = Mock() mock_app.instance_config = Mock() - mock_app.instance_config.data = { - 'concurrency': { - 'session': 5 - } - } + mock_app.instance_config.data = {'concurrency': {'session': 5}} return mock_app @pytest.fixture @@ -201,17 +194,13 @@ class TestSessionManagerGetConversation: return query @pytest.mark.asyncio - async def test_creates_conversation_with_prompt( - self, mock_app_with_config, sample_query, sample_session - ): + async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session): """Test that get_conversation creates conversation with prompt.""" sessionmgr = get_session_module() manager = sessionmgr.SessionManager(mock_app_with_config) - prompt_config = [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ] + prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] pipeline_uuid = 'pipeline-123' bot_uuid = 'bot-123' @@ -234,21 +223,15 @@ class TestSessionManagerGetConversation: manager = sessionmgr.SessionManager(mock_app_with_config) - prompt_config = [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ] + prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] pipeline_uuid = 'pipeline-123' bot_uuid = 'bot-123' # First call creates conversation - conv1 = await manager.get_conversation( - sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid - ) + conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid) # Second call with same pipeline should return same conversation - conv2 = await manager.get_conversation( - sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid - ) + conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid) assert conv1 is conv2 assert len(sample_session.conversations) == 1 @@ -262,36 +245,26 @@ class TestSessionManagerGetConversation: manager = sessionmgr.SessionManager(mock_app_with_config) - prompt_config = [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ] + prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] # First call with pipeline1 - conv1 = await manager.get_conversation( - sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1' - ) + conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1') # Second call with different pipeline should create new conversation - conv2 = await manager.get_conversation( - sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2' - ) + conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2') assert conv1 is not conv2 assert len(sample_session.conversations) == 2 assert sample_session.using_conversation is conv2 @pytest.mark.asyncio - async def test_conversation_has_empty_messages( - self, mock_app_with_config, sample_query, sample_session - ): + async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session): """Test that created conversation has empty messages list.""" sessionmgr = get_session_module() manager = sessionmgr.SessionManager(mock_app_with_config) - prompt_config = [ - {'role': 'system', 'content': 'You are a helpful assistant.'} - ] + prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}] conversation = await manager.get_conversation( sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123' @@ -300,22 +273,17 @@ class TestSessionManagerGetConversation: assert conversation.messages == [] @pytest.mark.asyncio - async def test_prompt_messages_from_config( - self, mock_app_with_config, sample_query, sample_session - ): + async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session): """Test that prompt messages are created from prompt_config.""" sessionmgr = get_session_module() manager = sessionmgr.SessionManager(mock_app_with_config) - prompt_config = [ - {'role': 'system', 'content': 'System message'}, - {'role': 'user', 'content': 'User message'} - ] + prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}] conversation = await manager.get_conversation( sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123' ) assert conversation.prompt.name == 'default' - assert len(conversation.prompt.messages) == 2 \ No newline at end of file + assert len(conversation.prompt.messages) == 2 diff --git a/tests/unit_tests/provider/test_tool_manager.py b/tests/unit_tests/provider/test_tool_manager.py index 2fcf25fb7..0ae33115c 100644 --- a/tests/unit_tests/provider/test_tool_manager.py +++ b/tests/unit_tests/provider/test_tool_manager.py @@ -136,6 +136,7 @@ class TestToolManagerSchemaGeneration: assert 'description' in func assert 'parameters' in func + class TestToolManagerExecuteFuncCall: """Tests for execute_func_call method.""" diff --git a/tests/unit_tests/rag/test_i18n_conversion.py b/tests/unit_tests/rag/test_i18n_conversion.py index a4604e656..203f43454 100644 --- a/tests/unit_tests/rag/test_i18n_conversion.py +++ b/tests/unit_tests/rag/test_i18n_conversion.py @@ -3,6 +3,7 @@ Tests cover: - _to_i18n_name() static method """ + from __future__ import annotations from importlib import import_module @@ -60,4 +61,4 @@ class TestToI18nName: kbmgr = get_kbmgr_module() input_dict = {'en_US': 'English', 'extra_key': 'extra_value'} result = kbmgr.RAGManager._to_i18n_name(input_dict) - assert result == {'en_US': 'English', 'extra_key': 'extra_value'} \ No newline at end of file + assert result == {'en_US': 'English', 'extra_key': 'extra_value'} diff --git a/tests/unit_tests/rag/test_kbmgr.py b/tests/unit_tests/rag/test_kbmgr.py index ae044ebe8..a1a16118d 100644 --- a/tests/unit_tests/rag/test_kbmgr.py +++ b/tests/unit_tests/rag/test_kbmgr.py @@ -6,6 +6,7 @@ Tests cover: - Knowledge engine enrichment - KB loading and removal """ + from __future__ import annotations import pytest @@ -101,13 +102,9 @@ class TestRAGManagerCreateKnowledgeBase: rag_module = get_rag_module() mock_app = create_mock_app() - mock_app.plugin_connector.list_knowledge_engines = AsyncMock( - return_value=[{'plugin_id': 'author/engine'}] - ) + mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}]) mock_app.persistence_mgr.execute_async = AsyncMock() - mock_app.plugin_connector.rag_on_kb_create = AsyncMock( - side_effect=Exception('Plugin error') - ) + mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin error')) manager = rag_module.RAGManager(mock_app) @@ -128,9 +125,7 @@ class TestRAGManagerCreateKnowledgeBase: rag_module = get_rag_module() mock_app = create_mock_app() - mock_app.plugin_connector.list_knowledge_engines = AsyncMock( - return_value=[{'plugin_id': 'author/engine'}] - ) + mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}]) mock_app.persistence_mgr.execute_async = AsyncMock() mock_app.plugin_connector.rag_on_kb_create = AsyncMock() @@ -206,9 +201,7 @@ class TestRuntimeKnowledgeBaseOnKBCreate: mock_app = create_mock_app() mock_kb = create_mock_kb_entity() - mock_app.plugin_connector.rag_on_kb_create = AsyncMock( - side_effect=Exception('Plugin failed') - ) + mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin failed')) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) @@ -245,9 +238,7 @@ class TestRuntimeKnowledgeBaseIngestDocument: mock_app = create_mock_app() mock_kb = create_mock_kb_entity() - mock_app.plugin_connector.call_rag_ingest = AsyncMock( - return_value={'status': 'success'} - ) + mock_app.plugin_connector.call_rag_ingest = AsyncMock(return_value={'status': 'success'}) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) @@ -304,14 +295,10 @@ class TestRAGManagerLoadKnowledgeBasesFromDB: # KB that will cause initialize to fail mock_kb = create_mock_kb_entity() - mock_app.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(all=Mock(return_value=[mock_kb])) - ) + mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb]))) # Make initialize fail by having plugin_connector throw error - mock_app.plugin_connector.rag_on_kb_create = AsyncMock( - side_effect=Exception('Init failed') - ) + mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Init failed')) manager = rag_module.RAGManager(mock_app) # Should not raise - errors are caught @@ -411,9 +398,7 @@ class TestRuntimeKnowledgeBaseRetrieve: mock_kb = create_mock_kb_entity() mock_kb.retrieval_settings = {} - mock_app.plugin_connector.call_rag_retrieve = AsyncMock( - return_value={'results': []} - ) + mock_app.plugin_connector.call_rag_retrieve = AsyncMock(return_value={'results': []}) runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) @@ -682,9 +667,7 @@ class TestRAGManagerGetAllDetails: """Test returns empty list when no knowledge bases.""" rag_module = get_rag_module() mock_app = create_mock_app() - mock_app.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(all=Mock(return_value=[])) - ) + mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[]))) manager = rag_module.RAGManager(mock_app) result = await manager.get_all_knowledge_base_details() @@ -699,9 +682,7 @@ class TestRAGManagerGetAllDetails: # Mock DB result mock_kb_row = Mock() - mock_app.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(all=Mock(return_value=[mock_kb_row])) - ) + mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb_row]))) mock_app.persistence_mgr.serialize_model = Mock( return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'} ) @@ -724,9 +705,7 @@ class TestRAGManagerGetDetails: """Test returns None when KB doesn't exist.""" rag_module = get_rag_module() mock_app = create_mock_app() - mock_app.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(first=Mock(return_value=None)) - ) + mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=None))) manager = rag_module.RAGManager(mock_app) result = await manager.get_knowledge_base_details('nonexistent') @@ -740,9 +719,7 @@ class TestRAGManagerGetDetails: mock_app = create_mock_app() mock_kb_row = Mock() - mock_app.persistence_mgr.execute_async = AsyncMock( - return_value=Mock(first=Mock(return_value=mock_kb_row)) - ) + mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=mock_kb_row))) mock_app.persistence_mgr.serialize_model = Mock( return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'} ) @@ -791,4 +768,4 @@ class TestRAGManagerLoadKnowledgeBase: await manager.load_knowledge_base(kb_dict) - assert 'kb-uuid' in manager.knowledge_bases \ No newline at end of file + assert 'kb-uuid' in manager.knowledge_bases diff --git a/tests/unit_tests/rag/test_runtime_service.py b/tests/unit_tests/rag/test_runtime_service.py index b5c60ccba..650b3bf2f 100644 --- a/tests/unit_tests/rag/test_runtime_service.py +++ b/tests/unit_tests/rag/test_runtime_service.py @@ -121,10 +121,12 @@ class TestRAGRuntimeServiceVectorSearch: """Create mock app.""" mock_app = MagicMock() mock_app.vector_db_mgr = MagicMock() - mock_app.vector_db_mgr.search = AsyncMock(return_value=[ - {'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}}, - {'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}}, - ]) + mock_app.vector_db_mgr.search = AsyncMock( + return_value=[ + {'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}}, + {'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}}, + ] + ) return mock_app def _make_rag_import_mocks(self): @@ -301,10 +303,7 @@ class TestRAGRuntimeServiceVectorList: mock_app = MagicMock() mock_app.vector_db_mgr = MagicMock() mock_app.vector_db_mgr.list_by_filter = AsyncMock( - return_value=( - [{'id': 'id1', 'metadata': {'file_id': 'abc'}}], - 10 - ) + return_value=([{'id': 'id1', 'metadata': {'file_id': 'abc'}}], 10) ) return mock_app diff --git a/tests/unit_tests/storage/test_localstorage_path_traversal.py b/tests/unit_tests/storage/test_localstorage_path_traversal.py index 8c5ebf527..5e950eb32 100644 --- a/tests/unit_tests/storage/test_localstorage_path_traversal.py +++ b/tests/unit_tests/storage/test_localstorage_path_traversal.py @@ -21,8 +21,8 @@ from langbot.pkg.storage.providers.localstorage import LocalStorageProvider @pytest.fixture def storage_provider(tmp_path): """Create a LocalStorageProvider with a temporary storage path.""" - storage_path = str(tmp_path / "storage") - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + storage_path = str(tmp_path / 'storage') + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): mock_app = Mock() provider = LocalStorageProvider(mock_app) yield provider, storage_path @@ -35,15 +35,15 @@ class TestPathTraversalPrevention: async def test_absolute_path_save_rejected(self, storage_provider, tmp_path): """Saving with an absolute path key must be blocked.""" provider, storage_path = storage_provider - target_file = str(tmp_path / "pwned.txt") + target_file = str(tmp_path / 'pwned.txt') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with pytest.raises((ValueError, PermissionError)): - await provider.save(target_file, b"malicious content") + await provider.save(target_file, b'malicious content') # The file must NOT exist outside the storage directory assert not os.path.exists(target_file), ( - f"Path traversal succeeded: file was written outside storage to {target_file}" + f'Path traversal succeeded: file was written outside storage to {target_file}' ) @pytest.mark.asyncio @@ -52,32 +52,28 @@ class TestPathTraversalPrevention: provider, storage_path = storage_provider # Create a file outside the storage directory - target_file = str(tmp_path / "secret.txt") - with open(target_file, "wb") as f: - f.write(b"secret data") + target_file = str(tmp_path / 'secret.txt') + with open(target_file, 'wb') as f: + f.write(b'secret data') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with pytest.raises((ValueError, PermissionError, FileNotFoundError)): data = await provider.load(target_file) - assert data != b"secret data", ( - "Path traversal succeeded: read file outside storage" - ) + assert data != b'secret data', 'Path traversal succeeded: read file outside storage' @pytest.mark.asyncio async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path): """Exists check with an absolute path key must be blocked or return False.""" provider, storage_path = storage_provider - target_file = str(tmp_path / "check_me.txt") - with open(target_file, "wb") as f: - f.write(b"data") + target_file = str(tmp_path / 'check_me.txt') + with open(target_file, 'wb') as f: + f.write(b'data') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): try: result = await provider.exists(target_file) - assert result is False, ( - "Path traversal succeeded: exists() returned True for file outside storage" - ) + assert result is False, 'Path traversal succeeded: exists() returned True for file outside storage' except (ValueError, PermissionError): pass # Expected @@ -86,28 +82,26 @@ class TestPathTraversalPrevention: """Deleting with an absolute path key must be blocked.""" provider, storage_path = storage_provider - target_file = str(tmp_path / "do_not_delete.txt") - with open(target_file, "wb") as f: - f.write(b"important data") + target_file = str(tmp_path / 'do_not_delete.txt') + with open(target_file, 'wb') as f: + f.write(b'important data') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with pytest.raises((ValueError, PermissionError, FileNotFoundError)): await provider.delete(target_file) - assert os.path.exists(target_file), ( - "Path traversal succeeded: file outside storage was deleted" - ) + assert os.path.exists(target_file), 'Path traversal succeeded: file outside storage was deleted' @pytest.mark.asyncio async def test_absolute_path_size_rejected(self, storage_provider, tmp_path): """Size check with an absolute path key must be blocked.""" provider, storage_path = storage_provider - target_file = str(tmp_path / "measure_me.txt") - with open(target_file, "wb") as f: - f.write(b"some data") + target_file = str(tmp_path / 'measure_me.txt') + with open(target_file, 'wb') as f: + f.write(b'some data') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with pytest.raises((ValueError, PermissionError, FileNotFoundError)): await provider.size(target_file) @@ -116,41 +110,39 @@ class TestPathTraversalPrevention: """Relative path traversal with '..' must be blocked.""" provider, storage_path = storage_provider - target_file = str(tmp_path / "above_storage.txt") - with open(target_file, "wb") as f: - f.write(b"above storage secret") + target_file = str(tmp_path / 'above_storage.txt') + with open(target_file, 'wb') as f: + f.write(b'above storage secret') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): - relative_key = os.path.join("..", "above_storage.txt") + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): + relative_key = os.path.join('..', 'above_storage.txt') with pytest.raises((ValueError, PermissionError, FileNotFoundError)): data = await provider.load(relative_key) - assert data != b"above storage secret" + assert data != b'above storage secret' @pytest.mark.asyncio async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path): """delete_dir_recursive with traversal path must be blocked.""" provider, storage_path = storage_provider - outside_dir = tmp_path / "outside_dir" + outside_dir = tmp_path / 'outside_dir' outside_dir.mkdir() - (outside_dir / "file.txt").write_text("important") + (outside_dir / 'file.txt').write_text('important') - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): with pytest.raises((ValueError, PermissionError)): await provider.delete_dir_recursive(str(outside_dir)) - assert outside_dir.exists(), ( - "Path traversal succeeded: directory outside storage was deleted" - ) + assert outside_dir.exists(), 'Path traversal succeeded: directory outside storage was deleted' @pytest.mark.asyncio async def test_legitimate_key_works(self, storage_provider): """Normal keys without traversal must still work.""" provider, storage_path = storage_provider - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): - key = "test_image_abc123.png" - content = b"PNG image data" + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): + key = 'test_image_abc123.png' + content = b'PNG image data' await provider.save(key, content) assert await provider.exists(key) is True @@ -166,9 +158,9 @@ class TestPathTraversalPrevention: """Keys with legitimate subdirectories must still work.""" provider, storage_path = storage_provider - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): - key = "bot_log_images/img_001.png" - content = b"PNG image data" + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): + key = 'bot_log_images/img_001.png' + content = b'PNG image data' await provider.save(key, content) assert await provider.exists(key) is True @@ -181,33 +173,33 @@ class TestPathTraversalPrevention: """delete_dir_recursive should handle non-existing directories gracefully.""" provider, storage_path = storage_provider - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): # Try to delete a non-existing directory - should not raise - await provider.delete_dir_recursive("nonexistent_dir") + await provider.delete_dir_recursive('nonexistent_dir') @pytest.mark.asyncio async def test_delete_dir_recursive_with_files(self, storage_provider): """delete_dir_recursive should delete directory with files inside.""" provider, storage_path = storage_provider - with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path): + with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path): # Create a directory with files - key1 = "test_dir/file1.txt" - key2 = "test_dir/file2.txt" - await provider.save(key1, b"content1") - await provider.save(key2, b"content2") + key1 = 'test_dir/file1.txt' + key2 = 'test_dir/file2.txt' + await provider.save(key1, b'content1') + await provider.save(key2, b'content2') # Verify files exist assert await provider.exists(key1) assert await provider.exists(key2) # Delete directory recursively - await provider.delete_dir_recursive("test_dir") + await provider.delete_dir_recursive('test_dir') # Verify files no longer exist assert not await provider.exists(key1) assert not await provider.exists(key2) -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/storage/test_s3storage.py b/tests/unit_tests/storage/test_s3storage.py index 20bf6f00e..eb3d0f7e6 100644 --- a/tests/unit_tests/storage/test_s3storage.py +++ b/tests/unit_tests/storage/test_s3storage.py @@ -8,6 +8,7 @@ Tests cover: Uses moto library to mock AWS S3 service. """ + from __future__ import annotations import pytest @@ -44,8 +45,10 @@ def mock_app_with_s3_config(): def s3_mock(): """Set up moto S3 mock context.""" from moto import mock_aws + with mock_aws(): import boto3 + # Create bucket for tests that need pre-existing bucket s3 = boto3.client('s3', region_name='us-east-1') yield s3 @@ -325,4 +328,4 @@ class TestS3StorageProviderErrorHandling: await provider.initialize() with pytest.raises(Exception): - await provider.size('nonexistent.txt') \ No newline at end of file + await provider.size('nonexistent.txt') diff --git a/tests/unit_tests/storage/test_storage_manager.py b/tests/unit_tests/storage/test_storage_manager.py index c0b64cae4..d96f1cb04 100644 --- a/tests/unit_tests/storage/test_storage_manager.py +++ b/tests/unit_tests/storage/test_storage_manager.py @@ -31,7 +31,7 @@ class TestStorageMgr: storage_mgr = StorageMgr(mock_app) - with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock): + with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): await storage_mgr.initialize() assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) mock_app.logger.info.assert_called() @@ -41,12 +41,12 @@ class TestStorageMgr: """Should use local storage when explicitly configured.""" mock_app = Mock() mock_app.instance_config = Mock() - mock_app.instance_config.data = {"storage": {"use": "local"}} + mock_app.instance_config.data = {'storage': {'use': 'local'}} mock_app.logger = Mock() storage_mgr = StorageMgr(mock_app) - with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock): + with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): await storage_mgr.initialize() assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) @@ -55,14 +55,12 @@ class TestStorageMgr: """Should use S3 storage when configured.""" mock_app = Mock() mock_app.instance_config = Mock() - mock_app.instance_config.data = { - "storage": {"use": "s3", "s3": {"endpoint_url": "https://s3.amazonaws.com"}} - } + mock_app.instance_config.data = {'storage': {'use': 's3', 's3': {'endpoint_url': 'https://s3.amazonaws.com'}}} mock_app.logger = Mock() storage_mgr = StorageMgr(mock_app) - with patch.object(S3StorageProvider, "initialize", new_callable=AsyncMock): + with patch.object(S3StorageProvider, 'initialize', new_callable=AsyncMock): await storage_mgr.initialize() assert isinstance(storage_mgr.storage_provider, S3StorageProvider) @@ -71,12 +69,12 @@ class TestStorageMgr: """Should default to local storage for invalid storage type.""" mock_app = Mock() mock_app.instance_config = Mock() - mock_app.instance_config.data = {"storage": {"use": "invalid_type"}} + mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}} mock_app.logger = Mock() storage_mgr = StorageMgr(mock_app) - with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock): + with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock): await storage_mgr.initialize() assert isinstance(storage_mgr.storage_provider, LocalStorageProvider) @@ -90,9 +88,7 @@ class TestStorageMgr: storage_mgr = StorageMgr(mock_app) - with patch.object( - LocalStorageProvider, "initialize", new_callable=AsyncMock - ) as mock_init: + with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock) as mock_init: await storage_mgr.initialize() mock_init.assert_called_once() @@ -105,8 +101,8 @@ class TestStorageProviderBase: mock_app = Mock() # Use LocalStorageProvider as concrete implementation - with patch("os.path.exists", return_value=True): - with patch("os.makedirs"): + with patch('os.path.exists', return_value=True): + with patch('os.makedirs'): provider = LocalStorageProvider(mock_app) assert provider.ap == mock_app @@ -115,12 +111,12 @@ class TestStorageProviderBase: """Provider base initialize should be callable and do nothing.""" mock_app = Mock() - with patch("os.path.exists", return_value=True): - with patch("os.makedirs"): + with patch('os.path.exists', return_value=True): + with patch('os.makedirs'): provider = LocalStorageProvider(mock_app) # Initialize should not raise await provider.initialize() -if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/telemetry/test_telemetry.py b/tests/unit_tests/telemetry/test_telemetry.py index 2ceb1f09c..b15a989ee 100644 --- a/tests/unit_tests/telemetry/test_telemetry.py +++ b/tests/unit_tests/telemetry/test_telemetry.py @@ -8,6 +8,7 @@ Tests cover: - HTTP request success/failure scenarios - Source code bug: send_tasks should be instance variable """ + from __future__ import annotations import pytest @@ -38,6 +39,7 @@ class TestTelemetryManagerInit: manager = telemetry.TelemetryManager(mock_app) assert manager.telemetry_config == {} + class TestTelemetryManagerInitialize: """Tests for initialize() method.""" @@ -218,7 +220,7 @@ class TestPayloadSanitization: # All null string fields should be empty strings for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']: - assert result[field] == '', f"Field {field} should be empty string, got {result[field]}" + assert result[field] == '', f'Field {field} should be empty string, got {result[field]}' @pytest.mark.asyncio async def test_sanitize_string_fields_preserve_values(self): @@ -418,9 +420,7 @@ class TestHTTPScenarios: manager.telemetry_config = {'url': 'https://example.com'} mock_response = Mock( - status_code=200, - text='{"code": 0, "msg": "success"}', - json=Mock(return_value={'code': 0, 'msg': 'success'}) + status_code=200, text='{"code": 0, "msg": "success"}', json=Mock(return_value={'code': 0, 'msg': 'success'}) ) mock_client = Mock() @@ -448,9 +448,7 @@ class TestHTTPScenarios: manager.telemetry_config = {'url': 'https://example.com'} mock_response = Mock( - status_code=500, - text='Internal Server Error', - json=Mock(return_value={'code': 500, 'msg': 'error'}) + status_code=500, text='Internal Server Error', json=Mock(return_value={'code': 500, 'msg': 'error'}) ) mock_client = Mock() @@ -478,7 +476,7 @@ class TestHTTPScenarios: mock_response = Mock( status_code=200, text='{"code": 400, "msg": "Bad Request"}', - json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}) + json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}), ) mock_client = Mock() @@ -493,7 +491,7 @@ class TestHTTPScenarios: assert mock_app.logger.warning.call_count >= 1 # Check that one of the calls contains application error info all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list] - assert any('400' in w for w in all_warnings), f"No warning contained error code 400: {all_warnings}" + assert any('400' in w for w in all_warnings), f'No warning contained error code 400: {all_warnings}' @pytest.mark.asyncio async def test_send_timeout_logs_warning(self): diff --git a/tests/unit_tests/utils/test_funcschema.py b/tests/unit_tests/utils/test_funcschema.py index c2b3bffe0..2d9e2d575 100644 --- a/tests/unit_tests/utils/test_funcschema.py +++ b/tests/unit_tests/utils/test_funcschema.py @@ -9,6 +9,7 @@ Tests cover: Note: Do NOT use 'from __future__ import annotations' because funcschema.py expects actual type objects, not string annotations. """ + import pytest from importlib import import_module diff --git a/tests/unit_tests/utils/test_image.py b/tests/unit_tests/utils/test_image.py index 291ba8c07..4a42717ba 100644 --- a/tests/unit_tests/utils/test_image.py +++ b/tests/unit_tests/utils/test_image.py @@ -20,55 +20,53 @@ class TestGetQQImageDownloadableUrl: def test_basic_url(self): """Parse basic image URL.""" - url = "http://example.com/image.jpg" + url = 'http://example.com/image.jpg' result_url, query = get_qq_image_downloadable_url(url) - assert result_url == "http://example.com/image.jpg" + assert result_url == 'http://example.com/image.jpg' assert query == {} def test_url_with_query_params(self): """Parse URL with query parameters.""" - url = "http://example.com/image.jpg?param1=value1¶m2=value2" + url = 'http://example.com/image.jpg?param1=value1¶m2=value2' result_url, query = get_qq_image_downloadable_url(url) - assert result_url == "http://example.com/image.jpg" - assert query == {"param1": ["value1"], "param2": ["value2"]} + assert result_url == 'http://example.com/image.jpg' + assert query == {'param1': ['value1'], 'param2': ['value2']} def test_url_with_port(self): """Parse URL with port number.""" - url = "http://example.com:8080/image.jpg" + url = 'http://example.com:8080/image.jpg' result_url, query = get_qq_image_downloadable_url(url) - assert result_url == "http://example.com:8080/image.jpg" + assert result_url == 'http://example.com:8080/image.jpg' def test_url_with_path(self): """Parse URL with complex path.""" - url = "http://example.com/path/to/image.jpg" + url = 'http://example.com/path/to/image.jpg' result_url, query = get_qq_image_downloadable_url(url) - assert result_url == "http://example.com/path/to/image.jpg" + assert result_url == 'http://example.com/path/to/image.jpg' def test_url_with_fragment(self): """Parse URL with fragment (fragment is not part of query).""" - url = "http://example.com/image.jpg#fragment" + url = 'http://example.com/image.jpg#fragment' result_url, query = get_qq_image_downloadable_url(url) # Fragment is not included in query string parsing - assert "http://example.com/image.jpg" in result_url + assert 'http://example.com/image.jpg' in result_url def test_https_url(self): """Parse HTTPS URL and preserve its scheme.""" - url = "https://example.com/image.jpg" + url = 'https://example.com/image.jpg' result_url, query = get_qq_image_downloadable_url(url) - assert result_url == "https://example.com/image.jpg" + assert result_url == 'https://example.com/image.jpg' assert query == {} def test_preserves_qq_https_scheme_and_query(self): """QQ image URLs keep HTTPS and query parameters.""" - result_url, query = get_qq_image_downloadable_url( - 'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1' - ) + result_url, query = get_qq_image_downloadable_url('https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1') assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0' assert query == {'term': ['2'], 'is_origin': ['1']} @@ -88,50 +86,50 @@ class TestExtractB64AndFormat: async def test_jpeg_data_uri(self): """Extract base64 and format from JPEG data URI.""" # Create a simple base64 string - original_data = b"test image data" + original_data = b'test image data' b64_data = base64.b64encode(original_data).decode() - data_uri = f"data:image/jpeg;base64,{b64_data}" + data_uri = f'data:image/jpeg;base64,{b64_data}' result_b64, result_format = await extract_b64_and_format(data_uri) assert result_b64 == b64_data - assert result_format == "jpeg" + assert result_format == 'jpeg' @pytest.mark.asyncio async def test_png_data_uri(self): """Extract base64 and format from PNG data URI.""" - original_data = b"test png data" + original_data = b'test png data' b64_data = base64.b64encode(original_data).decode() - data_uri = f"data:image/png;base64,{b64_data}" + data_uri = f'data:image/png;base64,{b64_data}' result_b64, result_format = await extract_b64_and_format(data_uri) assert result_b64 == b64_data - assert result_format == "png" + assert result_format == 'png' @pytest.mark.asyncio async def test_gif_data_uri(self): """Extract base64 and format from GIF data URI.""" - original_data = b"test gif data" + original_data = b'test gif data' b64_data = base64.b64encode(original_data).decode() - data_uri = f"data:image/gif;base64,{b64_data}" + data_uri = f'data:image/gif;base64,{b64_data}' result_b64, result_format = await extract_b64_and_format(data_uri) assert result_b64 == b64_data - assert result_format == "gif" + assert result_format == 'gif' @pytest.mark.asyncio async def test_webp_data_uri(self): """Extract base64 and format from WebP data URI.""" - original_data = b"test webp data" + original_data = b'test webp data' b64_data = base64.b64encode(original_data).decode() - data_uri = f"data:image/webp;base64,{b64_data}" + data_uri = f'data:image/webp;base64,{b64_data}' result_b64, result_format = await extract_b64_and_format(data_uri) assert result_b64 == b64_data - assert result_format == "webp" + assert result_format == 'webp' @pytest.mark.asyncio async def test_complex_base64(self): @@ -139,7 +137,7 @@ class TestExtractB64AndFormat: # Base64 can include + and / characters original_data = bytes(range(256)) # All byte values b64_data = base64.b64encode(original_data).decode() - data_uri = f"data:image/png;base64,{b64_data}" + data_uri = f'data:image/png;base64,{b64_data}' result_b64, result_format = await extract_b64_and_format(data_uri) @@ -150,9 +148,9 @@ class TestExtractB64AndFormat: @pytest.mark.asyncio async def test_empty_base64(self): """Handle empty base64 string.""" - data_uri = "data:image/png;base64," + data_uri = 'data:image/png;base64,' result_b64, result_format = await extract_b64_and_format(data_uri) - assert result_b64 == "" - assert result_format == "png" + assert result_b64 == '' + assert result_format == 'png' diff --git a/tests/unit_tests/utils/test_importutil.py b/tests/unit_tests/utils/test_importutil.py index b0ea0ad7a..bf0e4e050 100644 --- a/tests/unit_tests/utils/test_importutil.py +++ b/tests/unit_tests/utils/test_importutil.py @@ -23,52 +23,52 @@ class TestImportDir: def test_calls_importlib_for_each_python_file(self, tmp_path): """Should call importlib.import_module for each .py file.""" - module_dir = tmp_path / "test_modules" + module_dir = tmp_path / 'test_modules' module_dir.mkdir() - (module_dir / "__init__.py").write_text("") - (module_dir / "module_a.py").write_text("VALUE_A = 'a'\n") - (module_dir / "module_b.py").write_text("VALUE_B = 'b'\n") - (module_dir / "readme.txt").write_text("not a module") + (module_dir / '__init__.py').write_text('') + (module_dir / 'module_a.py').write_text("VALUE_A = 'a'\n") + (module_dir / 'module_b.py').write_text("VALUE_B = 'b'\n") + (module_dir / 'readme.txt').write_text('not a module') from langbot.pkg.utils import importutil - with patch.object(importlib, "import_module") as mock_import: - importutil.import_dir(str(module_dir), path_prefix="test_prefix.") + with patch.object(importlib, 'import_module') as mock_import: + importutil.import_dir(str(module_dir), path_prefix='test_prefix.') # Should call import_module for each .py file (excluding __init__.py) assert mock_import.call_count == 2 def test_skips_init_py(self, tmp_path): """Should skip __init__.py when importing.""" - module_dir = tmp_path / "test_modules" + module_dir = tmp_path / 'test_modules' module_dir.mkdir() - (module_dir / "__init__.py").write_text("") - (module_dir / "regular.py").write_text("VALUE = 1\n") + (module_dir / '__init__.py').write_text('') + (module_dir / 'regular.py').write_text('VALUE = 1\n') from langbot.pkg.utils import importutil - with patch.object(importlib, "import_module") as mock_import: - importutil.import_dir(str(module_dir), path_prefix="test_prefix.") + with patch.object(importlib, 'import_module') as mock_import: + importutil.import_dir(str(module_dir), path_prefix='test_prefix.') # __init__.py should be skipped mock_import.assert_called_once() # The call should not include __init__ call_args = mock_import.call_args[0][0] - assert "__init__" not in call_args + assert '__init__' not in call_args def test_ignores_non_py_files(self, tmp_path): """Should ignore non-.py files.""" - module_dir = tmp_path / "test_modules" + module_dir = tmp_path / 'test_modules' module_dir.mkdir() - (module_dir / "module.py").write_text("VALUE = 1\n") - (module_dir / "readme.txt").write_text("text") - (module_dir / "data.json").write_text("{}") + (module_dir / 'module.py').write_text('VALUE = 1\n') + (module_dir / 'readme.txt').write_text('text') + (module_dir / 'data.json').write_text('{}') from langbot.pkg.utils import importutil - with patch.object(importlib, "import_module") as mock_import: - importutil.import_dir(str(module_dir), path_prefix="test_prefix.") + with patch.object(importlib, 'import_module') as mock_import: + importutil.import_dir(str(module_dir), path_prefix='test_prefix.') # Only .py files should be imported assert mock_import.call_count == 1 @@ -79,14 +79,14 @@ class TestImportModulesInPkg: def test_imports_modules_from_package(self, tmp_path): """Should import all modules from a package object.""" mock_pkg = MagicMock() - mock_pkg.__file__ = str(tmp_path / "__init__.py") + mock_pkg.__file__ = str(tmp_path / '__init__.py') - (tmp_path / "__init__.py").write_text("") - (tmp_path / "mod1.py").write_text("MOD1 = 1\n") + (tmp_path / '__init__.py').write_text('') + (tmp_path / 'mod1.py').write_text('MOD1 = 1\n') from langbot.pkg.utils import importutil - with patch.object(importutil, "import_dir") as mock_import_dir: + with patch.object(importutil, 'import_dir') as mock_import_dir: importutil.import_modules_in_pkg(mock_pkg) mock_import_dir.assert_called_once() call_path = mock_import_dir.call_args[0][0] @@ -101,11 +101,11 @@ class TestImportModulesInPkgs: from langbot.pkg.utils import importutil mock_pkg1 = MagicMock() - mock_pkg1.__file__ = "/path/to/pkg1/__init__.py" + mock_pkg1.__file__ = '/path/to/pkg1/__init__.py' mock_pkg2 = MagicMock() - mock_pkg2.__file__ = "/path/to/pkg2/__init__.py" + mock_pkg2.__file__ = '/path/to/pkg2/__init__.py' - with patch.object(importutil, "import_modules_in_pkg") as mock_import: + with patch.object(importutil, 'import_modules_in_pkg') as mock_import: importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2]) assert mock_import.call_count == 2 @@ -116,18 +116,18 @@ class TestImportDotStyleDir: def test_converts_dot_notation_to_path(self, tmp_path): """Should convert dot notation to path and import.""" # Create structure matching the dot notation - (tmp_path / "my").mkdir() - (tmp_path / "my" / "pkg").mkdir() - (tmp_path / "my" / "pkg" / "test").mkdir() + (tmp_path / 'my').mkdir() + (tmp_path / 'my' / 'pkg').mkdir() + (tmp_path / 'my' / 'pkg' / 'test').mkdir() from langbot.pkg.utils import importutil - with patch.object(importutil, "import_dir") as mock_import_dir: - importutil.import_dot_style_dir("my.pkg.test") + with patch.object(importutil, 'import_dir') as mock_import_dir: + importutil.import_dot_style_dir('my.pkg.test') # The path should be converted using os.path.join call_path = mock_import_dir.call_args[0][0] # Should contain the path components joined - assert "my" in call_path + assert 'my' in call_path class TestReadResourceFile: @@ -137,16 +137,16 @@ class TestReadResourceFile: """Should read content from a resource file.""" from langbot.pkg.utils import importutil - content = importutil.read_resource_file("templates/config.yaml") - assert "admins:" in content - assert "edition: community" in content + content = importutil.read_resource_file('templates/config.yaml') + assert 'admins:' in content + assert 'edition: community' in content def test_raises_for_nonexistent_file(self): """Should raise exception for non-existent resource file.""" from langbot.pkg.utils import importutil with pytest.raises((FileNotFoundError, Exception)): - importutil.read_resource_file("nonexistent/path/file.txt") + importutil.read_resource_file('nonexistent/path/file.txt') class TestReadResourceFileBytes: @@ -156,16 +156,16 @@ class TestReadResourceFileBytes: """Should read content as bytes from a resource file.""" from langbot.pkg.utils import importutil - content = importutil.read_resource_file_bytes("templates/config.yaml") - assert b"admins:" in content - assert b"edition: community" in content + content = importutil.read_resource_file_bytes('templates/config.yaml') + assert b'admins:' in content + assert b'edition: community' in content def test_raises_for_nonexistent_file_bytes(self): """Should raise exception for non-existent resource file.""" from langbot.pkg.utils import importutil with pytest.raises((FileNotFoundError, Exception)): - importutil.read_resource_file_bytes("nonexistent/path/file.txt") + importutil.read_resource_file_bytes('nonexistent/path/file.txt') class TestListResourceFiles: @@ -175,9 +175,9 @@ class TestListResourceFiles: """Should list files in a resource directory.""" from langbot.pkg.utils import importutil - files = importutil.list_resource_files("templates") - assert "config.yaml" in files - assert "default-pipeline-config.json" in files + files = importutil.list_resource_files('templates') + assert 'config.yaml' in files + assert 'default-pipeline-config.json' in files assert all(isinstance(file, str) for file in files) def test_raises_for_nonexistent_directory(self): @@ -185,8 +185,8 @@ class TestListResourceFiles: from langbot.pkg.utils import importutil with pytest.raises((FileNotFoundError, Exception)): - importutil.list_resource_files("nonexistent_directory_xyz") + importutil.list_resource_files('nonexistent_directory_xyz') -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/utils/test_platform.py b/tests/unit_tests/utils/test_platform.py index 76a64a052..4f3e1a5da 100644 --- a/tests/unit_tests/utils/test_platform.py +++ b/tests/unit_tests/utils/test_platform.py @@ -5,6 +5,7 @@ Tests cover: - Docker environment detection - WebSocket plugin runtime mode """ + from __future__ import annotations import os @@ -86,4 +87,4 @@ class TestGetPlatform: assert platform_module.use_websocket_to_connect_plugin_runtime() is True # Restore - platform_module.standalone_runtime = original \ No newline at end of file + platform_module.standalone_runtime = original diff --git a/tests/unit_tests/utils/test_proxy.py b/tests/unit_tests/utils/test_proxy.py index 572375194..09bc44cc5 100644 --- a/tests/unit_tests/utils/test_proxy.py +++ b/tests/unit_tests/utils/test_proxy.py @@ -60,10 +60,12 @@ class TestProxyManager: async def test_initialize_config_overrides_env(self): """Config proxy overrides environment variables.""" - mock_app = self._create_mock_app(proxy_config={ - 'http': 'http://config-proxy:8080', - 'https': 'https://config-proxy:8443', - }) + mock_app = self._create_mock_app( + proxy_config={ + 'http': 'http://config-proxy:8080', + 'https': 'https://config-proxy:8443', + } + ) with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}): pm = ProxyManager(mock_app) @@ -74,10 +76,12 @@ class TestProxyManager: async def test_initialize_sets_env_variables(self): """initialize sets proxy to environment variables.""" - mock_app = self._create_mock_app(proxy_config={ - 'http': 'http://test-proxy:8080', - 'https': 'https://test-proxy:8443', - }) + mock_app = self._create_mock_app( + proxy_config={ + 'http': 'http://test-proxy:8080', + 'https': 'https://test-proxy:8443', + } + ) pm = ProxyManager(mock_app) await pm.initialize() @@ -143,9 +147,11 @@ class TestProxyManager: async def test_initialize_http_only_config(self): """initialize handles http-only config.""" - mock_app = self._create_mock_app(proxy_config={ - 'http': 'http://http-only:8080', - }) + mock_app = self._create_mock_app( + proxy_config={ + 'http': 'http://http-only:8080', + } + ) # Clear any existing proxy env vars env_backup = {} diff --git a/tests/unit_tests/utils/test_runner.py b/tests/unit_tests/utils/test_runner.py index 28f5d8e52..5fc092cf2 100644 --- a/tests/unit_tests/utils/test_runner.py +++ b/tests/unit_tests/utils/test_runner.py @@ -29,63 +29,63 @@ class TestGetRunnerCategory: def test_empty_url_returns_unknown(self): """Empty or None URL should return UNKNOWN.""" - assert get_runner_category("test", "") == RunnerCategory.UNKNOWN - assert get_runner_category("test", None) == RunnerCategory.UNKNOWN + assert get_runner_category('test', '') == RunnerCategory.UNKNOWN + assert get_runner_category('test', None) == RunnerCategory.UNKNOWN def test_localhost_returns_local(self): """localhost URL should be categorized as LOCAL.""" - assert get_runner_category("test", "http://localhost:3000") == RunnerCategory.LOCAL - assert get_runner_category("test", "https://localhost") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://localhost:3000') == RunnerCategory.LOCAL + assert get_runner_category('test', 'https://localhost') == RunnerCategory.LOCAL def test_127_0_0_1_returns_local(self): """127.0.0.1 URL should be categorized as LOCAL.""" - assert get_runner_category("test", "http://127.0.0.1:8080") == RunnerCategory.LOCAL - assert get_runner_category("test", "https://127.0.0.1") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://127.0.0.1:8080') == RunnerCategory.LOCAL + assert get_runner_category('test', 'https://127.0.0.1') == RunnerCategory.LOCAL def test_0_0_0_0_returns_local(self): """0.0.0.0 URL should be categorized as LOCAL.""" - assert get_runner_category("test", "http://0.0.0.0:8080") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://0.0.0.0:8080') == RunnerCategory.LOCAL def test_private_ip_192_168_returns_local(self): """192.168.x.x private IP should be categorized as LOCAL.""" - assert get_runner_category("test", "http://192.168.1.1:3000") == RunnerCategory.LOCAL - assert get_runner_category("test", "http://192.168.0.100") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://192.168.1.1:3000') == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://192.168.0.100') == RunnerCategory.LOCAL def test_private_ip_10_returns_local(self): """10.x.x.x private IP should be categorized as LOCAL.""" - assert get_runner_category("test", "http://10.0.0.1:8080") == RunnerCategory.LOCAL - assert get_runner_category("test", "http://10.255.255.255") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://10.0.0.1:8080') == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://10.255.255.255') == RunnerCategory.LOCAL def test_private_ip_172_16_31_returns_local(self): """172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL.""" - assert get_runner_category("test", "http://172.16.0.1:8080") == RunnerCategory.LOCAL - assert get_runner_category("test", "http://172.20.0.1") == RunnerCategory.LOCAL - assert get_runner_category("test", "http://172.31.255.255") == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://172.16.0.1:8080') == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://172.20.0.1') == RunnerCategory.LOCAL + assert get_runner_category('test', 'http://172.31.255.255') == RunnerCategory.LOCAL def test_n8n_cloud_returns_cloud(self): """n8n.cloud domain should be categorized as CLOUD.""" - assert get_runner_category("test", "https://myinstance.n8n.cloud") == RunnerCategory.CLOUD - assert get_runner_category("test", "https://test.n8n.io") == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://myinstance.n8n.cloud') == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://test.n8n.io') == RunnerCategory.CLOUD def test_dify_cloud_returns_cloud(self): """Dify cloud domains should be categorized as CLOUD.""" - assert get_runner_category("test", "https://api.dify.ai/v1") == RunnerCategory.CLOUD - assert get_runner_category("test", "https://cloud.dify.ai") == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://api.dify.ai/v1') == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://cloud.dify.ai') == RunnerCategory.CLOUD def test_coze_cloud_returns_cloud(self): """Coze domains should be categorized as CLOUD.""" - assert get_runner_category("test", "https://api.coze.com") == RunnerCategory.CLOUD - assert get_runner_category("test", "https://api.coze.cn") == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://api.coze.com') == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://api.coze.cn') == RunnerCategory.CLOUD def test_langflow_cloud_returns_cloud(self): """Langflow domains should be categorized as CLOUD.""" - assert get_runner_category("test", "https://cloud.langflow.ai") == RunnerCategory.CLOUD - assert get_runner_category("test", "https://test.langflow.org") == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://cloud.langflow.ai') == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://test.langflow.org') == RunnerCategory.CLOUD def test_other_url_returns_cloud(self): """Other URLs should default to CLOUD category.""" - assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD - assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://example.com') == RunnerCategory.CLOUD + assert get_runner_category('test', 'https://myserver.example.org') == RunnerCategory.CLOUD @pytest.mark.parametrize( 'runner_url', @@ -101,7 +101,7 @@ class TestGetRunnerCategory: ) def test_invalid_urls_return_unknown(self, runner_url): """Invalid or incomplete URLs should return UNKNOWN.""" - assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN + assert get_runner_category('test', runner_url) == RunnerCategory.UNKNOWN def test_urlparse_exception_returns_unknown(self): """Exception during URL parsing should return UNKNOWN.""" @@ -109,15 +109,15 @@ class TestGetRunnerCategory: from langbot.pkg.utils import runner def mock_urlparse(url): - raise Exception("URL parsing failed") + raise Exception('URL parsing failed') - with patch("langbot.pkg.utils.runner.urlparse", side_effect=mock_urlparse): - result = runner.get_runner_category("test", "http://example.com") + with patch('langbot.pkg.utils.runner.urlparse', side_effect=mock_urlparse): + result = runner.get_runner_category('test', 'http://example.com') assert result == RunnerCategory.UNKNOWN def test_url_without_scheme_returns_unknown(self): """URL without scheme should return UNKNOWN.""" - assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN + assert get_runner_category('test', 'example.com') == RunnerCategory.UNKNOWN @pytest.mark.parametrize( 'runner_url', @@ -146,20 +146,21 @@ class TestGetRunnerCategory: """Domain names that only look like private IP prefixes should not be LOCAL.""" assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD + class TestIsCloudRunner: """Test is_cloud_runner helper function.""" def test_cloud_runner_returns_true(self): """Cloud URL should return True.""" - assert is_cloud_runner("test", "https://api.dify.ai") is True + assert is_cloud_runner('test', 'https://api.dify.ai') is True def test_local_runner_returns_false(self): """Local URL should return False.""" - assert is_cloud_runner("test", "http://localhost:3000") is False + assert is_cloud_runner('test', 'http://localhost:3000') is False def test_unknown_returns_false(self): """Unknown category should return False.""" - assert is_cloud_runner("test", None) is False + assert is_cloud_runner('test', None) is False class TestIsLocalRunner: @@ -167,15 +168,15 @@ class TestIsLocalRunner: def test_local_runner_returns_true(self): """Local URL should return True.""" - assert is_local_runner("test", "http://localhost:3000") is True + assert is_local_runner('test', 'http://localhost:3000') is True def test_cloud_runner_returns_false(self): """Cloud URL should return False.""" - assert is_local_runner("test", "https://api.dify.ai") is False + assert is_local_runner('test', 'https://api.dify.ai') is False def test_unknown_returns_false(self): """Unknown category should return False.""" - assert is_local_runner("test", None) is False + assert is_local_runner('test', None) is False class TestGetRunnerInfo: @@ -183,17 +184,17 @@ class TestGetRunnerInfo: def test_returns_dict_with_expected_keys(self): """Should return dict with name, url, and category keys.""" - info = get_runner_info("my-runner", "http://localhost:3000") - assert "name" in info - assert "url" in info - assert "category" in info + info = get_runner_info('my-runner', 'http://localhost:3000') + assert 'name' in info + assert 'url' in info + assert 'category' in info def test_includes_correct_values(self): """Should include correct values in dict.""" - info = get_runner_info("my-runner", "http://localhost:3000") - assert info["name"] == "my-runner" - assert info["url"] == "http://localhost:3000" - assert info["category"] == RunnerCategory.LOCAL + info = get_runner_info('my-runner', 'http://localhost:3000') + assert info['name'] == 'my-runner' + assert info['url'] == 'http://localhost:3000' + assert info['category'] == RunnerCategory.LOCAL class TestExtractRunnerUrl: @@ -203,74 +204,58 @@ class TestExtractRunnerUrl: """Should extract base-url from dify-service-api config.""" runner = Mock() runner.pipeline_config = {} - pipeline_config = { - "ai": { - "dify-service-api": {"base-url": "https://api.dify.ai"} - } - } - url = extract_runner_url("dify-service-api", runner, pipeline_config) - assert url == "https://api.dify.ai" + pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}} + url = extract_runner_url('dify-service-api', runner, pipeline_config) + assert url == 'https://api.dify.ai' def test_n8n_service_api_extracts_url(self): """Should extract webhook-url from n8n-service-api config.""" runner = Mock() runner.pipeline_config = {} - pipeline_config = { - "ai": { - "n8n-service-api": {"webhook-url": "https://my.n8n.cloud/webhook"} - } - } - url = extract_runner_url("n8n-service-api", runner, pipeline_config) - assert url == "https://my.n8n.cloud/webhook" + pipeline_config = {'ai': {'n8n-service-api': {'webhook-url': 'https://my.n8n.cloud/webhook'}}} + url = extract_runner_url('n8n-service-api', runner, pipeline_config) + assert url == 'https://my.n8n.cloud/webhook' def test_coze_api_extracts_url(self): """Should extract api-base from coze-api config.""" runner = Mock() runner.pipeline_config = {} - pipeline_config = { - "ai": { - "coze-api": {"api-base": "https://api.coze.com"} - } - } - url = extract_runner_url("coze-api", runner, pipeline_config) - assert url == "https://api.coze.com" + pipeline_config = {'ai': {'coze-api': {'api-base': 'https://api.coze.com'}}} + url = extract_runner_url('coze-api', runner, pipeline_config) + assert url == 'https://api.coze.com' def test_langflow_api_extracts_url(self): """Should extract base-url from langflow-api config.""" runner = Mock() runner.pipeline_config = {} - pipeline_config = { - "ai": { - "langflow-api": {"base-url": "https://cloud.langflow.ai"} - } - } - url = extract_runner_url("langflow-api", runner, pipeline_config) - assert url == "https://cloud.langflow.ai" + pipeline_config = {'ai': {'langflow-api': {'base-url': 'https://cloud.langflow.ai'}}} + url = extract_runner_url('langflow-api', runner, pipeline_config) + assert url == 'https://cloud.langflow.ai' def test_unknown_runner_returns_none(self): """Unknown runner name should return None.""" runner = Mock() runner.pipeline_config = {} pipeline_config = {} - url = extract_runner_url("unknown-runner", runner, pipeline_config) + url = extract_runner_url('unknown-runner', runner, pipeline_config) assert url is None def test_none_runner_returns_none(self): """None runner should return None.""" - url = extract_runner_url("test", None, {}) + url = extract_runner_url('test', None, {}) assert url is None def test_runner_without_pipeline_config_returns_none(self): """Runner without pipeline_config attribute should return None.""" runner = Mock(spec=[]) # Empty spec means no attributes - url = extract_runner_url("test", runner, {}) + url = extract_runner_url('test', runner, {}) assert url is None def test_none_pipeline_config_returns_none(self): """None pipeline_config should return None.""" runner = Mock() runner.pipeline_config = {} - url = extract_runner_url("dify-service-api", runner, None) + url = extract_runner_url('dify-service-api', runner, None) assert url is None def test_missing_ai_config_returns_none(self): @@ -278,7 +263,7 @@ class TestExtractRunnerUrl: runner = Mock() runner.pipeline_config = {} pipeline_config = {} - url = extract_runner_url("dify-service-api", runner, pipeline_config) + url = extract_runner_url('dify-service-api', runner, pipeline_config) assert url is None @@ -289,19 +274,15 @@ class TestGetRunnerCategoryFromRunner: """Should extract URL and return correct category.""" runner = Mock() runner.pipeline_config = {} - pipeline_config = { - "ai": { - "dify-service-api": {"base-url": "https://api.dify.ai"} - } - } - category = get_runner_category_from_runner("dify-service-api", runner, pipeline_config) + pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}} + category = get_runner_category_from_runner('dify-service-api', runner, pipeline_config) assert category == RunnerCategory.CLOUD def test_returns_unknown_for_missing_url(self): """Should return UNKNOWN when URL cannot be extracted.""" runner = Mock() runner.pipeline_config = {} - category = get_runner_category_from_runner("unknown", runner, {}) + category = get_runner_category_from_runner('unknown', runner, {}) assert category == RunnerCategory.UNKNOWN @@ -310,9 +291,9 @@ class TestConstants: def test_runner_category_constants(self): """RunnerCategory should have LOCAL, CLOUD, UNKNOWN.""" - assert RunnerCategory.LOCAL == "local" - assert RunnerCategory.CLOUD == "cloud" - assert RunnerCategory.UNKNOWN == "unknown" + assert RunnerCategory.LOCAL == 'local' + assert RunnerCategory.CLOUD == 'cloud' + assert RunnerCategory.UNKNOWN == 'unknown' def test_cloud_domains_not_empty(self): """CLOUD_DOMAINS should not be empty.""" @@ -323,5 +304,5 @@ class TestConstants: assert len(LOCAL_PATTERNS) > 0 -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/vector/test_filter_utils.py b/tests/unit_tests/vector/test_filter_utils.py index f4eefb284..2bbf4a1c9 100644 --- a/tests/unit_tests/vector/test_filter_utils.py +++ b/tests/unit_tests/vector/test_filter_utils.py @@ -68,11 +68,7 @@ class TestNormalizeFilter: def test_normalize_filter_multiple_conditions(self): """Multiple top-level keys are AND-ed (returned as multiple triples).""" - result = normalize_filter({ - 'file_id': 'abc', - 'status': {'$ne': 'deleted'}, - 'created_at': {'$gte': 1700000000} - }) + result = normalize_filter({'file_id': 'abc', 'status': {'$ne': 'deleted'}, 'created_at': {'$gte': 1700000000}}) assert len(result) == 3 # Order should match dict iteration order @@ -149,11 +145,7 @@ class TestStripUnsupportedFields: ('file_id', '$eq', 'def'), ] - result = strip_unsupported_fields( - triples, - {'file_id', 'chunk_uuid'}, - field_aliases={'uuid': 'chunk_uuid'} - ) + result = strip_unsupported_fields(triples, {'file_id', 'chunk_uuid'}, field_aliases={'uuid': 'chunk_uuid'}) assert len(result) == 2 # 'uuid' should be resolved to 'chunk_uuid' @@ -169,7 +161,7 @@ class TestStripUnsupportedFields: result = strip_unsupported_fields( triples, {'file_id'}, # chunk_uuid not supported - field_aliases={'uuid': 'chunk_uuid'} + field_aliases={'uuid': 'chunk_uuid'}, ) assert result == [] @@ -207,4 +199,5 @@ class TestSupportedOpsConstant: def test_supported_ops_is_frozenset(self): """SUPPORTED_OPS is a frozenset for immutability.""" from collections.abc import Set - assert isinstance(SUPPORTED_OPS, Set) \ No newline at end of file + + assert isinstance(SUPPORTED_OPS, Set) diff --git a/tests/unit_tests/vector/test_mgr.py b/tests/unit_tests/vector/test_mgr.py index bf588a53c..997861383 100644 --- a/tests/unit_tests/vector/test_mgr.py +++ b/tests/unit_tests/vector/test_mgr.py @@ -55,6 +55,7 @@ class TestVectorDBManagerInitialization: # Run initialize synchronously for test import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) # Chroma should be instantiated @@ -76,6 +77,7 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_chroma_class.assert_called_once_with(mock_app) @@ -96,6 +98,7 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_qdrant_class.assert_called_once_with(mock_app) @@ -115,6 +118,7 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_seekdb_class.assert_called_once_with(mock_app) @@ -123,11 +127,7 @@ class TestVectorDBManagerInitialization: """Milvus config with custom URI.""" vdb_config = { 'use': 'milvus', - 'milvus': { - 'uri': 'http://localhost:19530', - 'token': 'root:Milvus', - 'db_name': 'langbot_db' - } + 'milvus': {'uri': 'http://localhost:19530', 'token': 'root:Milvus', 'db_name': 'langbot_db'}, } mock_app = self._create_mock_app(vdb_config) @@ -141,13 +141,11 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_milvus_class.assert_called_once_with( - mock_app, - uri='http://localhost:19530', - token='root:Milvus', - db_name='langbot_db' + mock_app, uri='http://localhost:19530', token='root:Milvus', db_name='langbot_db' ) def test_initialize_milvus_backend_defaults(self): @@ -165,24 +163,15 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) # Should use default values - mock_milvus_class.assert_called_once_with( - mock_app, - uri='./data/milvus.db', - token=None, - db_name='default' - ) + mock_milvus_class.assert_called_once_with(mock_app, uri='./data/milvus.db', token=None, db_name='default') def test_initialize_pgvector_with_connection_string(self): """pgvector with connection string.""" - vdb_config = { - 'use': 'pgvector', - 'pgvector': { - 'connection_string': 'postgresql://user:pass@host:5432/langbot' - } - } + vdb_config = {'use': 'pgvector', 'pgvector': {'connection_string': 'postgresql://user:pass@host:5432/langbot'}} mock_app = self._create_mock_app(vdb_config) mocks = self._make_vector_import_mocks() @@ -195,11 +184,11 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_pgvector_class.assert_called_once_with( - mock_app, - connection_string='postgresql://user:pass@host:5432/langbot' + mock_app, connection_string='postgresql://user:pass@host:5432/langbot' ) def test_initialize_pgvector_with_individual_params(self): @@ -211,8 +200,8 @@ class TestVectorDBManagerInitialization: 'port': 5433, 'database': 'vectordb', 'user': 'admin', - 'password': 'secret' - } + 'password': 'secret', + }, } mock_app = self._create_mock_app(vdb_config) @@ -226,15 +215,11 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_pgvector_class.assert_called_once_with( - mock_app, - host='db.example.com', - port=5433, - database='vectordb', - user='admin', - password='secret' + mock_app, host='db.example.com', port=5433, database='vectordb', user='admin', password='secret' ) def test_initialize_pgvector_defaults(self): @@ -252,15 +237,11 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_pgvector_class.assert_called_once_with( - mock_app, - host='localhost', - port=5432, - database='langbot', - user='postgres', - password='postgres' + mock_app, host='localhost', port=5432, database='langbot', user='postgres', password='postgres' ) def test_initialize_unknown_backend_defaults_to_chroma(self): @@ -278,6 +259,7 @@ class TestVectorDBManagerInitialization: mgr = VectorDBManager(mock_app) import asyncio + asyncio.get_event_loop().run_until_complete(mgr.initialize()) mock_chroma_class.assert_called_once_with(mock_app) @@ -335,4 +317,4 @@ class TestVectorDBManagerProxies: mgr.vector_db = mock_vector_db result = mgr.get_supported_search_types() - assert result == ['vector', 'full_text'] \ No newline at end of file + assert result == ['vector', 'full_text'] diff --git a/tests/unit_tests/vector/test_vdb_base.py b/tests/unit_tests/vector/test_vdb_base.py index f67aec163..427df9f19 100644 --- a/tests/unit_tests/vector/test_vdb_base.py +++ b/tests/unit_tests/vector/test_vdb_base.py @@ -39,6 +39,7 @@ class TestVectorDatabaseAbstractMethods: def test_abstract_methods_required(self): """Subclass must implement all abstract methods.""" + class IncompleteVectorDB(VectorDatabase): pass @@ -47,11 +48,21 @@ class TestVectorDatabaseAbstractMethods: def test_supported_search_types_default(self): """Default supported_search_types returns [VECTOR].""" + class MinimalVectorDB(VectorDatabase): async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): pass - async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None): + async def search( + self, + collection, + query_embedding, + k=5, + search_type='vector', + query_text='', + filter=None, + vector_weight=None, + ): pass async def delete_by_file_id(self, collection, file_id): @@ -71,11 +82,21 @@ class TestVectorDatabaseAbstractMethods: def test_list_by_filter_default_implementation(self): """list_by_filter has default implementation returning empty.""" + class MinimalVectorDB(VectorDatabase): async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): pass - async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None): + async def search( + self, + collection, + query_embedding, + k=5, + search_type='vector', + query_text='', + filter=None, + vector_weight=None, + ): pass async def delete_by_file_id(self, collection, file_id): @@ -93,9 +114,8 @@ class TestVectorDatabaseAbstractMethods: db = MinimalVectorDB() # list_by_filter should return empty list and -1 for total import asyncio - result = asyncio.get_event_loop().run_until_complete( - db.list_by_filter('test_collection') - ) + + result = asyncio.get_event_loop().run_until_complete(db.list_by_filter('test_collection')) assert result == ([], -1) @@ -105,14 +125,17 @@ class TestVectorDatabaseInterface: @pytest.fixture def mock_vector_db(self): """Create a minimal mock VectorDatabase for testing.""" + class MockVectorDB(VectorDatabase): def __init__(self): self.add_embeddings = AsyncMock() - self.search = AsyncMock(return_value={ - 'ids': [['id1', 'id2']], - 'distances': [[0.1, 0.2]], - 'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]] - }) + self.search = AsyncMock( + return_value={ + 'ids': [['id1', 'id2']], + 'distances': [[0.1, 0.2]], + 'metadatas': [[{'key': 'val1'}, {'key': 'val2'}]], + } + ) self.delete_by_file_id = AsyncMock() self.delete_by_filter = AsyncMock(return_value=5) self.get_or_create_collection = AsyncMock() @@ -121,7 +144,16 @@ class TestVectorDatabaseInterface: async def add_embeddings(self, collection, ids, embeddings_list, metadatas, documents=None): pass - async def search(self, collection, query_embedding, k=5, search_type='vector', query_text='', filter=None, vector_weight=None): + async def search( + self, + collection, + query_embedding, + k=5, + search_type='vector', + query_text='', + filter=None, + vector_weight=None, + ): pass async def delete_by_file_id(self, collection, file_id): @@ -146,7 +178,7 @@ class TestVectorDatabaseInterface: ids=['id1', 'id2'], embeddings_list=[[0.1, 0.2], [0.3, 0.4]], metadatas=[{'a': 1}, {'b': 2}], - documents=['doc1', 'doc2'] + documents=['doc1', 'doc2'], ) mock_vector_db.add_embeddings.assert_called_once() @@ -162,7 +194,7 @@ class TestVectorDatabaseInterface: search_type='hybrid', query_text='search text', filter={'file_id': 'abc'}, - vector_weight=0.7 + vector_weight=0.7, ) mock_vector_db.search.assert_called_once() @@ -170,4 +202,4 @@ class TestVectorDatabaseInterface: async def test_delete_by_filter_returns_int(self, mock_vector_db): """delete_by_filter returns int count.""" result = await mock_vector_db.delete_by_filter('test', {'file_id': 'abc'}) - assert isinstance(result, int) \ No newline at end of file + assert isinstance(result, int) diff --git a/tests/unit_tests/vector/test_vdb_filter_conversion.py b/tests/unit_tests/vector/test_vdb_filter_conversion.py index 5499b9087..cc79f62a4 100644 --- a/tests/unit_tests/vector/test_vdb_filter_conversion.py +++ b/tests/unit_tests/vector/test_vdb_filter_conversion.py @@ -5,6 +5,7 @@ Tests cover: - _build_milvus_expr: Milvus boolean expression string conversion - _build_pg_conditions: PostgreSQL SQLAlchemy conditions conversion """ + from __future__ import annotations from importlib import import_module @@ -122,11 +123,13 @@ class TestQdrantFilterConversion: """Multiple conditions are combined in must/must_not.""" qdrant_module = get_qdrant_module() - result = qdrant_module._build_qdrant_filter({ - 'file_id': 'abc', - 'status': {'$ne': 'deleted'}, - 'created_at': {'$gte': 100}, - }) + result = qdrant_module._build_qdrant_filter( + { + 'file_id': 'abc', + 'status': {'$ne': 'deleted'}, + 'created_at': {'$gte': 100}, + } + ) assert len(result.must) == 2 # file_id eq + created_at gte assert len(result.must_not) == 1 # status ne @@ -198,10 +201,12 @@ class TestMilvusFilterConversion: """Multiple conditions are joined with 'and'.""" milvus_module = get_milvus_module() - result = milvus_module._build_milvus_expr({ - 'file_id': 'abc', - 'chunk_uuid': {'$ne': 'def'}, - }) + result = milvus_module._build_milvus_expr( + { + 'file_id': 'abc', + 'chunk_uuid': {'$ne': 'def'}, + } + ) assert 'and' in result assert 'file_id == "abc"' in result assert 'chunk_uuid != "def"' in result @@ -272,6 +277,7 @@ class TestPgVectorFilterConversion: assert len(result) == 1 # Verify it's a SQLAlchemy BinaryExpression from sqlalchemy.sql.expression import BinaryExpression + assert isinstance(result[0], BinaryExpression) def test_ne_operator_creates_inequality_condition(self): @@ -321,10 +327,12 @@ class TestPgVectorFilterConversion: """Multiple conditions return list of conditions.""" pgvector_module = get_pgvector_module() - result = pgvector_module._build_pg_conditions({ - 'file_id': 'abc', - 'chunk_uuid': {'$ne': 'def'}, - }) + result = pgvector_module._build_pg_conditions( + { + 'file_id': 'abc', + 'chunk_uuid': {'$ne': 'def'}, + } + ) assert len(result) == 2 @@ -349,11 +357,13 @@ class TestPgVectorFilterConversion: """Only supported fields (text, file_id, chunk_uuid) are kept.""" pgvector_module = get_pgvector_module() - result = pgvector_module._build_pg_conditions({ - 'text': {'$ne': ''}, - 'file_id': 'abc', - 'chunk_uuid': {'$in': ['x', 'y']}, - 'unsupported': 'value', - }) + result = pgvector_module._build_pg_conditions( + { + 'text': {'$ne': ''}, + 'file_id': 'abc', + 'chunk_uuid': {'$in': ['x', 'y']}, + 'unsupported': 'value', + } + ) - assert len(result) == 3 # Only supported fields \ No newline at end of file + assert len(result) == 3 # Only supported fields diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index a8ead047e..11b530011 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -1,3 +1,3 @@ """ Test utilities package. -""" \ No newline at end of file +""" diff --git a/tests/utils/import_isolation.py b/tests/utils/import_isolation.py index 7d4487a8f..9f2b3c583 100644 --- a/tests/utils/import_isolation.py +++ b/tests/utils/import_isolation.py @@ -26,6 +26,7 @@ from unittest.mock import MagicMock class MockLifecycleControlScope(enum.Enum): """Mock enum for breaking circular import in core.entities.""" + APPLICATION = 'application' PLATFORM = 'platform' PLUGIN = 'plugin' @@ -190,4 +191,4 @@ def get_handler_modules_to_clear(handler_name: str) -> list[str]: 'langbot.pkg.pipeline.process.handler', 'langbot.pkg.pipeline.process.handlers', f'langbot.pkg.pipeline.process.handlers.{handler_name}', - ] \ No newline at end of file + ] From f4a6edf7ec1bb235b282c72043204edd78454a45 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:05:52 -0400 Subject: [PATCH 03/16] refactor(web): unify settings dialogs into single dialog with sidebar Merge API integration, model settings, account settings and storage analysis into one SettingsDialog with a shadcn inner sidebar for section switching. Preserve existing ?action= query-param deep links (showModelSettings / showAccountSettings / showApiIntegrationSettings / showStorageAnalysis) by mapping each to a section. Extract reusable panels and keep ModelsDialog as a thin wrapper for the dynamic-form model picker. --- .../AccountSettingsDialog.tsx | 181 ----- .../AccountSettingsPanel.tsx | 170 +++++ ...tionDialog.tsx => ApiIntegrationPanel.tsx} | 457 ++++++------ .../components/home-sidebar/HomeSidebar.tsx | 121 ++-- .../components/models-dialog/ModelsDialog.tsx | 675 +----------------- .../components/models-dialog/ModelsPanel.tsx | 666 +++++++++++++++++ .../settings-dialog/SettingsDialog.tsx | 204 ++++++ .../StorageAnalysisDialog.tsx | 410 ----------- .../StorageAnalysisPanel.tsx | 390 ++++++++++ web/src/i18n/locales/en-US.ts | 3 + web/src/i18n/locales/es-ES.ts | 3 + web/src/i18n/locales/ja-JP.ts | 3 + web/src/i18n/locales/ru-RU.ts | 3 + web/src/i18n/locales/th-TH.ts | 3 + web/src/i18n/locales/vi-VN.ts | 3 + web/src/i18n/locales/zh-Hans.ts | 3 + web/src/i18n/locales/zh-Hant.ts | 3 + 17 files changed, 1720 insertions(+), 1578 deletions(-) delete mode 100644 web/src/app/home/components/account-settings-dialog/AccountSettingsDialog.tsx create mode 100644 web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx rename web/src/app/home/components/api-integration-dialog/{ApiIntegrationDialog.tsx => ApiIntegrationPanel.tsx} (61%) create mode 100644 web/src/app/home/components/models-dialog/ModelsPanel.tsx create mode 100644 web/src/app/home/components/settings-dialog/SettingsDialog.tsx delete mode 100644 web/src/app/home/components/storage-analysis-dialog/StorageAnalysisDialog.tsx create mode 100644 web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx diff --git a/web/src/app/home/components/account-settings-dialog/AccountSettingsDialog.tsx b/web/src/app/home/components/account-settings-dialog/AccountSettingsDialog.tsx deleted file mode 100644 index b658c9fab..000000000 --- a/web/src/app/home/components/account-settings-dialog/AccountSettingsDialog.tsx +++ /dev/null @@ -1,181 +0,0 @@ -import * as React from 'react'; -import { useState, useEffect } from 'react'; -import { toast } from 'sonner'; -import { useTranslation } from 'react-i18next'; -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, - DialogDescription, -} from '@/components/ui/dialog'; -import { Button } from '@/components/ui/button'; -import { - Item, - ItemMedia, - ItemContent, - ItemTitle, - ItemDescription, - ItemActions, -} from '@/components/ui/item'; -import { httpClient } from '@/app/infra/http/HttpClient'; -import { systemInfo } from '@/app/infra/http'; -import { Loader2, ExternalLink, KeyRound, Layers } from 'lucide-react'; -import PasswordChangeDialog from '../password-change-dialog/PasswordChangeDialog'; - -interface AccountSettingsDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; -} - -export default function AccountSettingsDialog({ - open, - onOpenChange, -}: AccountSettingsDialogProps) { - const { t } = useTranslation(); - const [accountType, setAccountType] = useState<'local' | 'space'>('local'); - const [hasPassword, setHasPassword] = useState(false); - const [userEmail, setUserEmail] = useState(''); - const [loading, setLoading] = useState(true); - const [spaceBindLoading, setSpaceBindLoading] = useState(false); - const [passwordDialogOpen, setPasswordDialogOpen] = useState(false); - - useEffect(() => { - if (open) { - loadUserInfo(); - } - }, [open]); - - async function loadUserInfo() { - setLoading(true); - try { - const info = await httpClient.getUserInfo(); - setAccountType(info.account_type); - setHasPassword(info.has_password); - setUserEmail(info.user); - } catch { - toast.error(t('common.error')); - } finally { - setLoading(false); - } - } - - const handleBindSpace = async () => { - setSpaceBindLoading(true); - try { - const token = localStorage.getItem('token'); - if (!token) { - toast.error(t('common.error')); - setSpaceBindLoading(false); - return; - } - const currentOrigin = window.location.origin; - const redirectUri = `${currentOrigin}/auth/space/callback?mode=bind`; - // Pass token as state for security verification - const response = await httpClient.getSpaceAuthorizeUrl( - redirectUri, - token, - ); - window.location.href = response.authorize_url; - } catch { - toast.error(t('common.spaceLoginFailed')); - setSpaceBindLoading(false); - } - }; - - const handlePasswordDialogClose = (dialogOpen: boolean) => { - setPasswordDialogOpen(dialogOpen); - if (!dialogOpen) { - // Reload user info to update password status - loadUserInfo(); - } - }; - - return ( - <> - - - - {t('account.settings')} - {userEmail} - - - {loading ? ( -
- -
- ) : ( -
- {/* Password Item */} - - - - - - {t('account.passwordStatus')} - - {hasPassword - ? t('account.passwordSetDescription') - : t('account.setPasswordHint')} - - - - - - - - {/* Space Account Item */} - - - - - - {t('account.spaceStatus')} - - {accountType === 'space' - ? t('account.spaceBoundDescription') - : t('account.bindSpaceDescription')} - - - {accountType === 'local' && ( - - - - )} - -
- )} -
-
- - - - ); -} diff --git a/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx b/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx new file mode 100644 index 000000000..5cf7e4c8b --- /dev/null +++ b/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx @@ -0,0 +1,170 @@ +import { useState, useEffect } from 'react'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; +import { Button } from '@/components/ui/button'; +import { + Item, + ItemMedia, + ItemContent, + ItemTitle, + ItemDescription, + ItemActions, +} from '@/components/ui/item'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { systemInfo } from '@/app/infra/http'; +import { Loader2, ExternalLink, KeyRound, Layers } from 'lucide-react'; +import PasswordChangeDialog from '../password-change-dialog/PasswordChangeDialog'; + +interface AccountSettingsPanelProps { + // True when this panel is the active section and the dialog is open. + active: boolean; + onEmailResolved?: (email: string) => void; +} + +export default function AccountSettingsPanel({ + active, + onEmailResolved, +}: AccountSettingsPanelProps) { + const { t } = useTranslation(); + const [accountType, setAccountType] = useState<'local' | 'space'>('local'); + const [hasPassword, setHasPassword] = useState(false); + const [userEmail, setUserEmail] = useState(''); + const [loading, setLoading] = useState(true); + const [spaceBindLoading, setSpaceBindLoading] = useState(false); + const [passwordDialogOpen, setPasswordDialogOpen] = useState(false); + + useEffect(() => { + if (active) { + loadUserInfo(); + } + }, [active]); + + async function loadUserInfo() { + setLoading(true); + try { + const info = await httpClient.getUserInfo(); + setAccountType(info.account_type); + setHasPassword(info.has_password); + setUserEmail(info.user); + onEmailResolved?.(info.user); + } catch { + toast.error(t('common.error')); + } finally { + setLoading(false); + } + } + + const handleBindSpace = async () => { + setSpaceBindLoading(true); + try { + const token = localStorage.getItem('token'); + if (!token) { + toast.error(t('common.error')); + setSpaceBindLoading(false); + return; + } + const currentOrigin = window.location.origin; + const redirectUri = `${currentOrigin}/auth/space/callback?mode=bind`; + // Pass token as state for security verification + const response = await httpClient.getSpaceAuthorizeUrl( + redirectUri, + token, + ); + window.location.href = response.authorize_url; + } catch { + toast.error(t('common.spaceLoginFailed')); + setSpaceBindLoading(false); + } + }; + + const handlePasswordDialogClose = (dialogOpen: boolean) => { + setPasswordDialogOpen(dialogOpen); + if (!dialogOpen) { + // Reload user info to update password status + loadUserInfo(); + } + }; + + return ( +
+ {userEmail && ( +

{userEmail}

+ )} + + {loading ? ( +
+ +
+ ) : ( +
+ {/* Password Item */} + + + + + + {t('account.passwordStatus')} + + {hasPassword + ? t('account.passwordSetDescription') + : t('account.setPasswordHint')} + + + + + + + + {/* Space Account Item */} + + + + + + {t('account.spaceStatus')} + + {accountType === 'space' + ? t('account.spaceBoundDescription') + : t('account.bindSpaceDescription')} + + + {accountType === 'local' && ( + + + + )} + +
+ )} + + +
+ ); +} diff --git a/web/src/app/home/components/api-integration-dialog/ApiIntegrationDialog.tsx b/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx similarity index 61% rename from web/src/app/home/components/api-integration-dialog/ApiIntegrationDialog.tsx rename to web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx index 8ac3f496b..e45d5f501 100644 --- a/web/src/app/home/components/api-integration-dialog/ApiIntegrationDialog.tsx +++ b/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx @@ -3,7 +3,6 @@ import { useState, useEffect, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { toast } from 'sonner'; import { Copy, Check, Trash2, Plus } from 'lucide-react'; -import { useNavigate, useLocation, useSearchParams } from 'react-router-dom'; import { Dialog, DialogContent, @@ -55,20 +54,15 @@ interface Webhook { created_at: string; } -interface ApiIntegrationDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; +interface ApiIntegrationPanelProps { + // True when this panel is the active section and the dialog is open. + active: boolean; } -export default function ApiIntegrationDialog({ - open, - onOpenChange, -}: ApiIntegrationDialogProps) { +export default function ApiIntegrationPanel({ + active, +}: ApiIntegrationPanelProps) { const { t } = useTranslation(); - const navigate = useNavigate(); - const location = useLocation(); - const pathname = location.pathname; - const [searchParams] = useSearchParams(); const [activeTab, setActiveTab] = useState('apikeys'); const [apiKeys, setApiKeys] = useState([]); const [webhooks, setWebhooks] = useState([]); @@ -91,33 +85,7 @@ export default function ApiIntegrationDialog({ ); const [copiedKey, setCopiedKey] = useState(null); - // Sync URL with dialog state - useEffect(() => { - if (open) { - const params = new URLSearchParams(searchParams.toString()); - params.set('action', 'showApiIntegrationSettings'); - navigate(`${pathname}?${params.toString()}`, { - preventScrollReset: true, - }); - } - }, [open]); - - const handleOpenChange = (newOpen: boolean) => { - if (!newOpen && (deleteKeyId || deleteWebhookId)) { - return; - } - if (!newOpen) { - const params = new URLSearchParams(searchParams.toString()); - params.delete('action'); - const newUrl = params.toString() - ? `${pathname}?${params.toString()}` - : pathname; - navigate(newUrl, { preventScrollReset: true }); - } - onOpenChange(newOpen); - }; - - // 清理 body 样式,防止对话框关闭后页面无法交互 + // 清理 body 样式,防止嵌套对话框关闭后页面无法交互 useEffect(() => { if (!deleteKeyId && !deleteWebhookId) { const cleanup = () => { @@ -131,11 +99,11 @@ export default function ApiIntegrationDialog({ }, [deleteKeyId, deleteWebhookId]); useEffect(() => { - if (open) { + if (active) { loadApiKeys(); loadWebhooks(); } - }, [open]); + }, [active]); const loadApiKeys = async () => { setLoading(true); @@ -284,233 +252,216 @@ export default function ApiIntegrationDialog({ return ( <> - - - - {t('common.manageApiIntegration')} - +
+ + + + {t('common.apiKeys')} + + + {t('common.webhooks')} + + - - - - {t('common.apiKeys')} - - + {t('common.apiKeyHint')} +
+ +
+ +
- {/* API Keys Tab */} - -
- {t('common.apiKeyHint')} + {loading ? ( +
+ {t('common.loading')}
- -
- + ) : apiKeys.length === 0 ? ( +
+ {t('common.noApiKeys')}
- - {loading ? ( -
- {t('common.loading')} -
- ) : apiKeys.length === 0 ? ( -
- {t('common.noApiKeys')} -
- ) : ( -
- - - - - {t('common.name')} - - - {t('common.apiKeyValue')} - - - {t('common.actions')} - - - - - {apiKeys.map((item) => ( - - -
-
{item.name}
- {item.description && ( -
- {item.description} -
- )} -
-
- - - {maskApiKey(item.key)} - - - -
- - -
-
-
- ))} -
-
-
- )} - - - {/* Webhooks Tab */} - -
- {t('common.webhookHint')} -
- -
- -
- - {loading ? ( -
- {t('common.loading')} -
- ) : webhooks.length === 0 ? ( -
- {t('common.noWebhooks')} -
- ) : ( -
- - - - - {t('common.name')} - - - {t('common.webhookUrl')} - - - {t('common.webhookEnabled')} - - - {t('common.actions')} - - - - - {webhooks.map((webhook) => ( - - -
-
- {webhook.name} + ) : ( +
+
+ + + + {t('common.name')} + + + {t('common.apiKeyValue')} + + + {t('common.actions')} + + + + + {apiKeys.map((item) => ( + + +
+
{item.name}
+ {item.description && ( +
+ {item.description}
- {webhook.description && ( -
- {webhook.description} -
- )} -
-
- -
- - {webhook.url} - -
-
- - - handleToggleWebhook(webhook) - } - /> - - + )} + + + + + {maskApiKey(item.key)} + + + +
+ - - - ))} - -
-
- )} -
- +
+ + + ))} + + +
+ )} +
- - - -
-
+ {/* Webhooks Tab */} + +
+ {t('common.webhookHint')} +
+ +
+ +
+ + {loading ? ( +
+ {t('common.loading')} +
+ ) : webhooks.length === 0 ? ( +
+ {t('common.noWebhooks')} +
+ ) : ( +
+ + + + + {t('common.name')} + + + {t('common.webhookUrl')} + + + {t('common.webhookEnabled')} + + + {t('common.actions')} + + + + + {webhooks.map((webhook) => ( + + +
+
+ {webhook.name} +
+ {webhook.description && ( +
+ {webhook.description} +
+ )} +
+
+ +
+ + {webhook.url} + +
+
+ + handleToggleWebhook(webhook)} + /> + + + + +
+ ))} +
+
+
+ )} +
+ + {/* Create API Key Dialog */} diff --git a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx index 2eff56cfa..d1f3acd77 100644 --- a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx +++ b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx @@ -57,11 +57,12 @@ import { Avatar, AvatarFallback } from '@/components/ui/avatar'; import { LanguageSelector } from '@/components/ui/language-selector'; import { Badge } from '@/components/ui/badge'; import { Button } from '@/components/ui/button'; -import AccountSettingsDialog from '@/app/home/components/account-settings-dialog/AccountSettingsDialog'; -import ApiIntegrationDialog from '@/app/home/components/api-integration-dialog/ApiIntegrationDialog'; import NewVersionDialog from '@/app/home/components/new-version-dialog/NewVersionDialog'; -import ModelsDialog from '@/app/home/components/models-dialog/ModelsDialog'; -import StorageAnalysisDialog from '@/app/home/components/storage-analysis-dialog/StorageAnalysisDialog'; +import SettingsDialog, { + SettingsSection, + SETTINGS_ACTION_BY_SECTION, + SETTINGS_SECTION_BY_ACTION, +} from '@/app/home/components/settings-dialog/SettingsDialog'; import { GitHubRelease } from '@/app/infra/http/CloudServiceClient'; import { useAsyncTask, AsyncTaskStatus } from '@/hooks/useAsyncTask'; import { toast } from 'sonner'; @@ -1548,17 +1549,10 @@ export default function HomeSidebar({ }, [pathname]); useEffect(() => { - if (searchParams.get('action') === 'showModelSettings') { - setModelsDialogOpen(true); - } - if (searchParams.get('action') === 'showAccountSettings') { - setAccountSettingsOpen(true); - } - if (searchParams.get('action') === 'showApiIntegrationSettings') { - setApiKeyDialogOpen(true); - } - if (searchParams.get('action') === 'showStorageAnalysis') { - setStorageAnalysisOpen(true); + const action = searchParams.get('action'); + if (action && SETTINGS_SECTION_BY_ACTION[action]) { + setSettingsSection(SETTINGS_SECTION_BY_ACTION[action]); + setSettingsOpen(true); } }, [searchParams]); @@ -1567,15 +1561,14 @@ export default function HomeSidebar({ useState>(loadSectionState); const { theme, setTheme } = useTheme(); const { t } = useTranslation(); - const [accountSettingsOpen, setAccountSettingsOpen] = useState(false); - const [apiKeyDialogOpen, setApiKeyDialogOpen] = useState(false); + const [settingsOpen, setSettingsOpen] = useState(false); + const [settingsSection, setSettingsSection] = + useState('models'); const [latestRelease, setLatestRelease] = useState( null, ); const [hasNewVersion, setHasNewVersion] = useState(false); const [versionDialogOpen, setVersionDialogOpen] = useState(false); - const [modelsDialogOpen, setModelsDialogOpen] = useState(false); - const [storageAnalysisOpen, setStorageAnalysisOpen] = useState(false); const [userEmail, setUserEmail] = useState(''); const [starCount, setStarCount] = useState(null); const [userMenuOpen, setUserMenuOpen] = useState(false); @@ -1600,51 +1593,28 @@ export default function HomeSidebar({ setShowScrollHint(false); }, 250); } - function handleModelsDialogChange(open: boolean) { - setModelsDialogOpen(open); - if (open) { - const params = new URLSearchParams(searchParams.toString()); - params.set('action', 'showModelSettings'); - navigate(`${pathname}?${params.toString()}`, { - preventScrollReset: true, - }); - } else { - const params = new URLSearchParams(searchParams.toString()); - params.delete('action'); - const newUrl = params.toString() - ? `${pathname}?${params.toString()}` - : pathname; - navigate(newUrl, { preventScrollReset: true }); - } + function openSettings(section: SettingsSection) { + setSettingsSection(section); + setSettingsOpen(true); + const params = new URLSearchParams(searchParams.toString()); + params.set('action', SETTINGS_ACTION_BY_SECTION[section]); + navigate(`${pathname}?${params.toString()}`, { + preventScrollReset: true, + }); } - function handleAccountSettingsChange(open: boolean) { - setAccountSettingsOpen(open); - if (open) { - const params = new URLSearchParams(searchParams.toString()); - params.set('action', 'showAccountSettings'); - navigate(`${pathname}?${params.toString()}`, { - preventScrollReset: true, - }); - } else { - const params = new URLSearchParams(searchParams.toString()); - params.delete('action'); - const newUrl = params.toString() - ? `${pathname}?${params.toString()}` - : pathname; - navigate(newUrl, { preventScrollReset: true }); - } + function handleSettingsSectionChange(section: SettingsSection) { + setSettingsSection(section); + const params = new URLSearchParams(searchParams.toString()); + params.set('action', SETTINGS_ACTION_BY_SECTION[section]); + navigate(`${pathname}?${params.toString()}`, { + preventScrollReset: true, + }); } - function handleStorageAnalysisChange(open: boolean) { - setStorageAnalysisOpen(open); - if (open) { - const params = new URLSearchParams(searchParams.toString()); - params.set('action', 'showStorageAnalysis'); - navigate(`${pathname}?${params.toString()}`, { - preventScrollReset: true, - }); - } else { + function handleSettingsOpenChange(open: boolean) { + setSettingsOpen(open); + if (!open) { const params = new URLSearchParams(searchParams.toString()); params.delete('action'); const newUrl = params.toString() @@ -1917,7 +1887,7 @@ export default function HomeSidebar({ setApiKeyDialogOpen(true)} + onClick={() => openSettings('apiIntegration')} tooltip={t('common.apiIntegration')} > @@ -1930,7 +1900,7 @@ export default function HomeSidebar({ handleModelsDialogChange(true)} + onClick={() => openSettings('models')} tooltip={t('models.title')} > @@ -2018,7 +1988,10 @@ export default function HomeSidebar({ {/* Account actions */} handleAccountSettingsChange(true)} + onClick={() => { + setUserMenuOpen(false); + openSettings('account'); + }} > {t('account.settings')} @@ -2026,7 +1999,7 @@ export default function HomeSidebar({ { setUserMenuOpen(false); - handleStorageAnalysisChange(true); + openSettings('storageAnalysis'); }} > @@ -2123,27 +2096,17 @@ export default function HomeSidebar({ - - - - ); } diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index ccb03b3c7..dc27f3174 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -1,677 +1,42 @@ -import { useState, useEffect } from 'react'; -import { Plus, Boxes } from 'lucide-react'; -import { httpClient, systemInfo } from '@/app/infra/http/HttpClient'; -import { ModelProvider } from '@/app/infra/entities/api'; +import { useState } from 'react'; import { Dialog, DialogContent, DialogHeader, DialogTitle, } from '@/components/ui/dialog'; -import { Button } from '@/components/ui/button'; -import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; -import ProviderForm from './component/provider-form/ProviderForm'; -import { ProviderCard } from './components'; -import { - ExtraArg, - ModelType, - ScanModelsResult, - SelectedScannedModel, - TestResult, - ProviderModels, - LANGBOT_MODELS_PROVIDER_REQUESTER, -} from './types'; -import { CustomApiError } from '@/app/infra/entities/common'; +import ModelsPanel from './ModelsPanel'; interface ModelsDialogProps { open: boolean; onOpenChange: (open: boolean) => void; } -type ExtraArgValue = string | number | boolean | Record; - -function convertExtraArgsToObject( - args: ExtraArg[], -): Record { - const obj: Record = {}; - args.forEach((arg) => { - if (!arg.key.trim()) return; - if (arg.type === 'number') { - obj[arg.key] = Number(arg.value); - } else if (arg.type === 'boolean') { - obj[arg.key] = arg.value === 'true'; - } else if (arg.type === 'object') { - const raw = arg.value.trim() || '{}'; - let parsed: unknown; - try { - parsed = JSON.parse(raw); - } catch { - throw new Error(`Invalid JSON for extra parameter "${arg.key}"`); - } - if ( - parsed === null || - typeof parsed !== 'object' || - Array.isArray(parsed) - ) { - throw new Error(`Extra parameter "${arg.key}" must be a JSON object`); - } - obj[arg.key] = parsed as Record; - } else { - obj[arg.key] = arg.value; - } - }); - return obj; -} - -function parseContextLength( - value: number | null | undefined, - invalidMessage: string, -): number | null { - if (value === undefined || value === null) return null; - if (!Number.isInteger(value) || value <= 0) { - throw new Error(invalidMessage); - } - return value; -} - +// Standalone Models dialog. The unified Settings dialog renders +// directly; this wrapper is kept for places that open Models on its own +// (e.g. the model picker inside dynamic forms). export default function ModelsDialog({ open, onOpenChange, }: ModelsDialogProps) { const { t } = useTranslation(); - - const [providers, setProviders] = useState([]); - const [accountType, setAccountType] = useState<'local' | 'space'>('local'); - const [spaceCredits, setSpaceCredits] = useState(null); - - // Expanded providers and their models - const [expandedProviders, setExpandedProviders] = useState>( - new Set(), - ); - const [providerModels, setProviderModels] = useState< - Record - >({}); - const [loadingProviders, setLoadingProviders] = useState>( - new Set(), - ); - - // Provider form modal - const [providerFormOpen, setProviderFormOpen] = useState(false); - const [editingProviderId, setEditingProviderId] = useState( - null, - ); - - // Map of requester name -> support_type[] (from requester manifests), - // used to restrict which model-type tabs are shown when adding models. - const [requesterSupportTypes, setRequesterSupportTypes] = useState< - Record - >({}); - - // Popover states - const [addModelPopoverOpen, setAddModelPopoverOpen] = useState( - null, - ); - const [editModelPopoverOpen, setEditModelPopoverOpen] = useState< - string | null - >(null); - const [deleteConfirmOpen, setDeleteConfirmOpen] = useState( - null, - ); - - // Form states - const [isSubmitting, setIsSubmitting] = useState(false); - const [isTesting, setIsTesting] = useState(false); - const [testResult, setTestResult] = useState(null); - - // Track if providers have been loaded initially - const [providersLoaded, setProvidersLoaded] = useState(false); - - // Separate LangBot Models provider (hide when models service is disabled) - const langbotProvider = systemInfo.disable_models_service - ? undefined - : providers.find((p) => p.requester === LANGBOT_MODELS_PROVIDER_REQUESTER); - const otherProviders = providers.filter( - (p) => p.requester !== LANGBOT_MODELS_PROVIDER_REQUESTER, - ); - - useEffect(() => { - if (open) { - loadUserInfo(); - loadProviders(); - loadRequesterSupportTypes(); - } - }, [open]); - - // Auto-expand LangBot Models when no external providers exist - useEffect(() => { - if (providersLoaded && langbotProvider && otherProviders.length === 0) { - if (!expandedProviders.has(langbotProvider.uuid)) { - setExpandedProviders(new Set([langbotProvider.uuid])); - if (!providerModels[langbotProvider.uuid]) { - loadProviderModels(langbotProvider.uuid); - } - } - } - }, [providersLoaded, providers]); - - async function loadUserInfo() { - try { - const userInfo = await httpClient.getUserInfo(); - setAccountType(userInfo.account_type); - if (userInfo.account_type === 'space') { - const creditsInfo = await httpClient.getSpaceCredits(); - setSpaceCredits(creditsInfo.credits); - } - } catch { - setAccountType('local'); - } - } - - async function loadProviders() { - try { - const resp = await httpClient.getModelProviders(); - setProviders(resp.providers); - setProvidersLoaded(true); - } catch (err) { - console.error('Failed to load providers', err); - toast.error(t('models.loadError')); - } - } - - async function loadRequesterSupportTypes() { - try { - const resp = await httpClient.getProviderRequesters(); - const map: Record = {}; - for (const r of resp.requesters) { - map[r.name] = r.spec?.support_type ?? []; - } - setRequesterSupportTypes(map); - } catch (err) { - console.error('Failed to load requester support types', err); - } - } - - async function loadProviderModels(providerUuid: string, silent = false) { - if (loadingProviders.has(providerUuid)) return; - - if (!silent) { - setLoadingProviders((prev) => new Set(prev).add(providerUuid)); - } - try { - const [llmResp, embeddingResp, rerankResp] = await Promise.all([ - httpClient.getProviderLLMModels(providerUuid), - httpClient.getProviderEmbeddingModels(providerUuid), - httpClient.getProviderRerankModels(providerUuid), - ]); - setProviderModels((prev) => ({ - ...prev, - [providerUuid]: { - llm: llmResp.models, - embedding: embeddingResp.models, - rerank: rerankResp.models, - }, - })); - } catch (err) { - console.error('Failed to load models', err); - } finally { - if (!silent) { - setLoadingProviders((prev) => { - const next = new Set(prev); - next.delete(providerUuid); - return next; - }); - } - } - } - - function toggleProvider(providerUuid: string) { - setExpandedProviders((prev) => { - const next = new Set(prev); - if (next.has(providerUuid)) { - next.delete(providerUuid); - } else { - next.add(providerUuid); - if (!providerModels[providerUuid]) { - loadProviderModels(providerUuid); - } - } - return next; - }); - } - - function handleCreateProvider() { - setEditingProviderId(null); - setProviderFormOpen(true); - } - - function handleEditProvider(providerId: string) { - setEditingProviderId(providerId); - setProviderFormOpen(true); - } - - async function handleDeleteProvider(providerId: string) { - try { - await httpClient.deleteModelProvider(providerId); - toast.success(t('models.providerDeleted')); - loadProviders(); - } catch (err) { - toast.error(t('models.providerDeleteError') + (err as Error).message); - } - } - - async function handleSpaceLogin() { - try { - const token = localStorage.getItem('token'); - if (!token) { - toast.error(t('common.error')); - return; - } - const currentOrigin = window.location.origin; - const redirectUri = `${currentOrigin}/auth/space/callback?mode=bind`; - const response = await httpClient.getSpaceAuthorizeUrl( - redirectUri, - token, - ); - window.location.href = response.authorize_url; - } catch { - toast.error(t('common.spaceLoginFailed')); - } - } - - async function handleAddModel( - providerUuid: string, - modelType: ModelType, - name: string, - abilities: string[], - extraArgs: ExtraArg[], - contextLength?: number | null, - ) { - if (!name.trim()) { - toast.error(t('models.modelNameRequired')); - return; - } - setIsSubmitting(true); - try { - const extraArgsObj = convertExtraArgsToObject(extraArgs); - - if (modelType === 'llm') { - await httpClient.createProviderLLMModel({ - name, - provider_uuid: providerUuid, - abilities, - context_length: parseContextLength( - contextLength, - t('models.contextLengthInvalid'), - ), - extra_args: extraArgsObj, - } as never); - } else if (modelType === 'embedding') { - await httpClient.createProviderEmbeddingModel({ - name, - provider_uuid: providerUuid, - extra_args: extraArgsObj, - } as never); - } else { - await httpClient.createProviderRerankModel({ - name, - provider_uuid: providerUuid, - extra_args: extraArgsObj, - } as never); - } - setAddModelPopoverOpen(null); - loadProviderModels(providerUuid, true); - loadProviders(); - } catch (err) { - toast.error(t('models.createError') + (err as Error).message); - } finally { - setIsSubmitting(false); - } - } - - async function handleScanModels( - providerUuid: string, - modelType?: ModelType, - ): Promise { - try { - const resp = await httpClient.scanProviderModels(providerUuid, modelType); - return { - models: resp.models, - debug: resp.debug, - }; - } catch (err) { - toast.error(t('models.getModelListError') + (err as CustomApiError).msg); - return { models: [] }; - } - } - - async function handleAddScannedModels( - providerUuid: string, - modelType: ModelType, - models: SelectedScannedModel[], - ) { - if (models.length === 0) return; - - setIsSubmitting(true); - try { - for (const item of models) { - const effectiveType = item.model.type || modelType; - if (effectiveType === 'llm') { - await httpClient.createProviderLLMModel({ - name: item.model.name, - provider_uuid: providerUuid, - abilities: item.abilities, - context_length: item.model.context_length ?? null, - extra_args: {}, - } as never); - } else if (effectiveType === 'embedding') { - await httpClient.createProviderEmbeddingModel({ - name: item.model.name, - provider_uuid: providerUuid, - extra_args: {}, - } as never); - } else { - await httpClient.createProviderRerankModel({ - name: item.model.name, - provider_uuid: providerUuid, - extra_args: {}, - } as never); - } - } - setAddModelPopoverOpen(null); - loadProviderModels(providerUuid, true); - loadProviders(); - toast.success( - t('models.addSelectedModelsSuccess', { count: models.length }), - ); - } catch (err) { - toast.error(t('models.createError') + (err as CustomApiError).msg); - } finally { - setIsSubmitting(false); - } - } - - async function handleUpdateModel( - providerUuid: string, - modelId: string, - modelType: ModelType, - name: string, - abilities: string[], - extraArgs: ExtraArg[], - contextLength?: number | null, - ) { - if (!name.trim()) { - toast.error(t('models.modelNameRequired')); - return; - } - setIsSubmitting(true); - try { - const extraArgsObj = convertExtraArgsToObject(extraArgs); - - if (modelType === 'llm') { - await httpClient.updateProviderLLMModel(modelId, { - name, - provider_uuid: providerUuid, - abilities, - context_length: parseContextLength( - contextLength, - t('models.contextLengthInvalid'), - ), - extra_args: extraArgsObj, - } as never); - } else if (modelType === 'embedding') { - await httpClient.updateProviderEmbeddingModel(modelId, { - name, - provider_uuid: providerUuid, - extra_args: extraArgsObj, - } as never); - } else { - await httpClient.updateProviderRerankModel(modelId, { - name, - provider_uuid: providerUuid, - extra_args: extraArgsObj, - } as never); - } - setEditModelPopoverOpen(null); - loadProviderModels(providerUuid, true); - loadProviders(); - } catch (err) { - toast.error(t('models.saveError') + (err as Error).message); - } finally { - setIsSubmitting(false); - } - } - - async function handleDeleteModel( - providerUuid: string, - modelId: string, - modelType: ModelType, - ) { - try { - if (modelType === 'llm') { - await httpClient.deleteProviderLLMModel(modelId); - } else if (modelType === 'embedding') { - await httpClient.deleteProviderEmbeddingModel(modelId); - } else { - await httpClient.deleteProviderRerankModel(modelId); - } - toast.success(t('models.deleteSuccess')); - loadProviderModels(providerUuid, true); - loadProviders(); - } catch (err) { - toast.error(t('models.deleteError') + (err as Error).message); - } - } - - async function handleTestModel( - providerUuid: string, - name: string, - modelType: ModelType, - abilities: string[], - extraArgs: ExtraArg[], - ) { - setIsTesting(true); - setTestResult(null); - const startTime = Date.now(); - try { - const extraArgsObj = convertExtraArgsToObject(extraArgs); - - // Get the provider info - const provider = providers.find((p) => p.uuid === providerUuid); - const providerData = { - requester: provider?.requester || '', - base_url: provider?.base_url || '', - api_keys: provider?.api_keys || [], - }; - - if (modelType === 'llm') { - await httpClient.testLLMModel('_', { - uuid: '', - name, - provider_uuid: '', - provider: providerData, - abilities, - extra_args: extraArgsObj, - } as never); - } else if (modelType === 'embedding') { - await httpClient.testEmbeddingModel('_', { - uuid: '', - name, - provider_uuid: '', - provider: providerData, - extra_args: extraArgsObj, - } as never); - } else { - await httpClient.testRerankModel('_', { - uuid: '', - name, - provider_uuid: '', - provider: providerData, - extra_args: extraArgsObj, - } as never); - } - const duration = Date.now() - startTime; - setTestResult({ success: true, duration }); - } catch (err) { - console.error('Failed to test model', err); - toast.error(t('models.testError') + ': ' + (err as CustomApiError).msg); - setTestResult(null); - } finally { - setIsTesting(false); - } - } - - function handleFormClose() { - setProviderFormOpen(false); - loadProviders(); - // Refresh expanded providers - expandedProviders.forEach((uuid) => loadProviderModels(uuid)); - } - - function renderProviderCard( - provider: ModelProvider, - isLangBotModels: boolean = false, - ) { - return ( - toggleProvider(provider.uuid)} - onEditProvider={() => handleEditProvider(provider.uuid)} - onDeleteProvider={() => handleDeleteProvider(provider.uuid)} - onSpaceLogin={handleSpaceLogin} - onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)} - onCloseAddModel={() => setAddModelPopoverOpen(null)} - onAddModel={(modelType, name, abilities, extraArgs, contextLength) => - handleAddModel( - provider.uuid, - modelType, - name, - abilities, - extraArgs, - contextLength, - ) - } - onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)} - onAddScannedModels={(modelType, models) => - handleAddScannedModels(provider.uuid, modelType, models) - } - onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)} - onCloseEditModel={() => setEditModelPopoverOpen(null)} - onUpdateModel={( - modelId, - modelType, - name, - abilities, - extraArgs, - contextLength, - ) => - handleUpdateModel( - provider.uuid, - modelId, - modelType, - name, - abilities, - extraArgs, - contextLength, - ) - } - onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)} - onCloseDeleteConfirm={() => setDeleteConfirmOpen(null)} - onDeleteModel={(modelId, modelType) => - handleDeleteModel(provider.uuid, modelId, modelType) - } - onTestModel={(name, modelType, abilities, extraArgs) => - handleTestModel(provider.uuid, name, modelType, abilities, extraArgs) - } - isSubmitting={isSubmitting} - isTesting={isTesting} - testResult={testResult} - onResetTestResult={() => setTestResult(null)} - /> - ); - } + const [blocking, setBlocking] = useState(false); return ( - <> - { - if (!newOpen && providerFormOpen) return; - onOpenChange(newOpen); - }} - > - - - {t('models.title')} - - -
- {/* LangBot Models Card */} - {langbotProvider && renderProviderCard(langbotProvider, true)} - - {/* Add Provider Button */} -
- - {otherProviders.length === 0 - ? t( - systemInfo.disable_models_service - ? 'models.addProviderHintSimple' - : 'models.addProviderHint', - ) - : t('models.providerCount', { count: otherProviders.length })} - -
- -
-
- - {/* Provider List */} - {otherProviders.length === 0 ? ( -
- -

{t('models.noProviders')}

-
- ) : ( - otherProviders.map((p) => renderProviderCard(p)) - )} -
-
-
- - - - - - {editingProviderId - ? t('models.editProvider') - : t('models.addProvider')} - - - setProviderFormOpen(false)} - /> - - - + { + if (!newOpen && blocking) return; + onOpenChange(newOpen); + }} + > + + + {t('models.title')} + + + + ); } diff --git a/web/src/app/home/components/models-dialog/ModelsPanel.tsx b/web/src/app/home/components/models-dialog/ModelsPanel.tsx new file mode 100644 index 000000000..7dd32c135 --- /dev/null +++ b/web/src/app/home/components/models-dialog/ModelsPanel.tsx @@ -0,0 +1,666 @@ +import { useState, useEffect } from 'react'; +import { Plus, Boxes } from 'lucide-react'; +import { httpClient, systemInfo } from '@/app/infra/http/HttpClient'; +import { ModelProvider } from '@/app/infra/entities/api'; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; +import ProviderForm from './component/provider-form/ProviderForm'; +import { ProviderCard } from './components'; +import { + ExtraArg, + ModelType, + ScanModelsResult, + SelectedScannedModel, + TestResult, + ProviderModels, + LANGBOT_MODELS_PROVIDER_REQUESTER, +} from './types'; +import { CustomApiError } from '@/app/infra/entities/common'; + +interface ModelsPanelProps { + // True when this panel is the active section and the dialog is open. + active: boolean; + // Notify parent when a nested modal (provider form) should block outer-close. + onBlockingChange?: (blocking: boolean) => void; +} + +type ExtraArgValue = string | number | boolean | Record; + +function convertExtraArgsToObject( + args: ExtraArg[], +): Record { + const obj: Record = {}; + args.forEach((arg) => { + if (!arg.key.trim()) return; + if (arg.type === 'number') { + obj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + obj[arg.key] = arg.value === 'true'; + } else if (arg.type === 'object') { + const raw = arg.value.trim() || '{}'; + let parsed: unknown; + try { + parsed = JSON.parse(raw); + } catch { + throw new Error(`Invalid JSON for extra parameter "${arg.key}"`); + } + if ( + parsed === null || + typeof parsed !== 'object' || + Array.isArray(parsed) + ) { + throw new Error(`Extra parameter "${arg.key}" must be a JSON object`); + } + obj[arg.key] = parsed as Record; + } else { + obj[arg.key] = arg.value; + } + }); + return obj; +} + +function parseContextLength( + value: number | null | undefined, + invalidMessage: string, +): number | null { + if (value === undefined || value === null) return null; + if (!Number.isInteger(value) || value <= 0) { + throw new Error(invalidMessage); + } + return value; +} + +export default function ModelsPanel({ + active, + onBlockingChange, +}: ModelsPanelProps) { + const { t } = useTranslation(); + + const [providers, setProviders] = useState([]); + const [accountType, setAccountType] = useState<'local' | 'space'>('local'); + const [spaceCredits, setSpaceCredits] = useState(null); + + // Expanded providers and their models + const [expandedProviders, setExpandedProviders] = useState>( + new Set(), + ); + const [providerModels, setProviderModels] = useState< + Record + >({}); + const [loadingProviders, setLoadingProviders] = useState>( + new Set(), + ); + + // Provider form modal + const [providerFormOpen, setProviderFormOpen] = useState(false); + const [editingProviderId, setEditingProviderId] = useState( + null, + ); + + // Map of requester name -> support_type[] (from requester manifests), + // used to restrict which model-type tabs are shown when adding models. + const [requesterSupportTypes, setRequesterSupportTypes] = useState< + Record + >({}); + + // Popover states + const [addModelPopoverOpen, setAddModelPopoverOpen] = useState( + null, + ); + const [editModelPopoverOpen, setEditModelPopoverOpen] = useState< + string | null + >(null); + const [deleteConfirmOpen, setDeleteConfirmOpen] = useState( + null, + ); + + // Form states + const [isSubmitting, setIsSubmitting] = useState(false); + const [isTesting, setIsTesting] = useState(false); + const [testResult, setTestResult] = useState(null); + + // Track if providers have been loaded initially + const [providersLoaded, setProvidersLoaded] = useState(false); + + // Separate LangBot Models provider (hide when models service is disabled) + const langbotProvider = systemInfo.disable_models_service + ? undefined + : providers.find((p) => p.requester === LANGBOT_MODELS_PROVIDER_REQUESTER); + const otherProviders = providers.filter( + (p) => p.requester !== LANGBOT_MODELS_PROVIDER_REQUESTER, + ); + + useEffect(() => { + if (active) { + loadUserInfo(); + loadProviders(); + loadRequesterSupportTypes(); + } + }, [active]); + + // Notify parent of blocking state so it can guard outer-close. + useEffect(() => { + onBlockingChange?.(providerFormOpen); + }, [providerFormOpen, onBlockingChange]); + + // Auto-expand LangBot Models when no external providers exist + useEffect(() => { + if (providersLoaded && langbotProvider && otherProviders.length === 0) { + if (!expandedProviders.has(langbotProvider.uuid)) { + setExpandedProviders(new Set([langbotProvider.uuid])); + if (!providerModels[langbotProvider.uuid]) { + loadProviderModels(langbotProvider.uuid); + } + } + } + }, [providersLoaded, providers]); + + async function loadUserInfo() { + try { + const userInfo = await httpClient.getUserInfo(); + setAccountType(userInfo.account_type); + if (userInfo.account_type === 'space') { + const creditsInfo = await httpClient.getSpaceCredits(); + setSpaceCredits(creditsInfo.credits); + } + } catch { + setAccountType('local'); + } + } + + async function loadProviders() { + try { + const resp = await httpClient.getModelProviders(); + setProviders(resp.providers); + setProvidersLoaded(true); + } catch (err) { + console.error('Failed to load providers', err); + toast.error(t('models.loadError')); + } + } + + async function loadRequesterSupportTypes() { + try { + const resp = await httpClient.getProviderRequesters(); + const map: Record = {}; + for (const r of resp.requesters) { + map[r.name] = r.spec?.support_type ?? []; + } + setRequesterSupportTypes(map); + } catch (err) { + console.error('Failed to load requester support types', err); + } + } + + async function loadProviderModels(providerUuid: string, silent = false) { + if (loadingProviders.has(providerUuid)) return; + + if (!silent) { + setLoadingProviders((prev) => new Set(prev).add(providerUuid)); + } + try { + const [llmResp, embeddingResp, rerankResp] = await Promise.all([ + httpClient.getProviderLLMModels(providerUuid), + httpClient.getProviderEmbeddingModels(providerUuid), + httpClient.getProviderRerankModels(providerUuid), + ]); + setProviderModels((prev) => ({ + ...prev, + [providerUuid]: { + llm: llmResp.models, + embedding: embeddingResp.models, + rerank: rerankResp.models, + }, + })); + } catch (err) { + console.error('Failed to load models', err); + } finally { + if (!silent) { + setLoadingProviders((prev) => { + const next = new Set(prev); + next.delete(providerUuid); + return next; + }); + } + } + } + + function toggleProvider(providerUuid: string) { + setExpandedProviders((prev) => { + const next = new Set(prev); + if (next.has(providerUuid)) { + next.delete(providerUuid); + } else { + next.add(providerUuid); + if (!providerModels[providerUuid]) { + loadProviderModels(providerUuid); + } + } + return next; + }); + } + + function handleCreateProvider() { + setEditingProviderId(null); + setProviderFormOpen(true); + } + + function handleEditProvider(providerId: string) { + setEditingProviderId(providerId); + setProviderFormOpen(true); + } + + async function handleDeleteProvider(providerId: string) { + try { + await httpClient.deleteModelProvider(providerId); + toast.success(t('models.providerDeleted')); + loadProviders(); + } catch (err) { + toast.error(t('models.providerDeleteError') + (err as Error).message); + } + } + + async function handleSpaceLogin() { + try { + const token = localStorage.getItem('token'); + if (!token) { + toast.error(t('common.error')); + return; + } + const currentOrigin = window.location.origin; + const redirectUri = `${currentOrigin}/auth/space/callback?mode=bind`; + const response = await httpClient.getSpaceAuthorizeUrl( + redirectUri, + token, + ); + window.location.href = response.authorize_url; + } catch { + toast.error(t('common.spaceLoginFailed')); + } + } + + async function handleAddModel( + providerUuid: string, + modelType: ModelType, + name: string, + abilities: string[], + extraArgs: ExtraArg[], + contextLength?: number | null, + ) { + if (!name.trim()) { + toast.error(t('models.modelNameRequired')); + return; + } + setIsSubmitting(true); + try { + const extraArgsObj = convertExtraArgsToObject(extraArgs); + + if (modelType === 'llm') { + await httpClient.createProviderLLMModel({ + name, + provider_uuid: providerUuid, + abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), + extra_args: extraArgsObj, + } as never); + } else if (modelType === 'embedding') { + await httpClient.createProviderEmbeddingModel({ + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); + } else { + await httpClient.createProviderRerankModel({ + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); + } + setAddModelPopoverOpen(null); + loadProviderModels(providerUuid, true); + loadProviders(); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } finally { + setIsSubmitting(false); + } + } + + async function handleScanModels( + providerUuid: string, + modelType?: ModelType, + ): Promise { + try { + const resp = await httpClient.scanProviderModels(providerUuid, modelType); + return { + models: resp.models, + debug: resp.debug, + }; + } catch (err) { + toast.error(t('models.getModelListError') + (err as CustomApiError).msg); + return { models: [] }; + } + } + + async function handleAddScannedModels( + providerUuid: string, + modelType: ModelType, + models: SelectedScannedModel[], + ) { + if (models.length === 0) return; + + setIsSubmitting(true); + try { + for (const item of models) { + const effectiveType = item.model.type || modelType; + if (effectiveType === 'llm') { + await httpClient.createProviderLLMModel({ + name: item.model.name, + provider_uuid: providerUuid, + abilities: item.abilities, + context_length: item.model.context_length ?? null, + extra_args: {}, + } as never); + } else if (effectiveType === 'embedding') { + await httpClient.createProviderEmbeddingModel({ + name: item.model.name, + provider_uuid: providerUuid, + extra_args: {}, + } as never); + } else { + await httpClient.createProviderRerankModel({ + name: item.model.name, + provider_uuid: providerUuid, + extra_args: {}, + } as never); + } + } + setAddModelPopoverOpen(null); + loadProviderModels(providerUuid, true); + loadProviders(); + toast.success( + t('models.addSelectedModelsSuccess', { count: models.length }), + ); + } catch (err) { + toast.error(t('models.createError') + (err as CustomApiError).msg); + } finally { + setIsSubmitting(false); + } + } + + async function handleUpdateModel( + providerUuid: string, + modelId: string, + modelType: ModelType, + name: string, + abilities: string[], + extraArgs: ExtraArg[], + contextLength?: number | null, + ) { + if (!name.trim()) { + toast.error(t('models.modelNameRequired')); + return; + } + setIsSubmitting(true); + try { + const extraArgsObj = convertExtraArgsToObject(extraArgs); + + if (modelType === 'llm') { + await httpClient.updateProviderLLMModel(modelId, { + name, + provider_uuid: providerUuid, + abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), + extra_args: extraArgsObj, + } as never); + } else if (modelType === 'embedding') { + await httpClient.updateProviderEmbeddingModel(modelId, { + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); + } else { + await httpClient.updateProviderRerankModel(modelId, { + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); + } + setEditModelPopoverOpen(null); + loadProviderModels(providerUuid, true); + loadProviders(); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } finally { + setIsSubmitting(false); + } + } + + async function handleDeleteModel( + providerUuid: string, + modelId: string, + modelType: ModelType, + ) { + try { + if (modelType === 'llm') { + await httpClient.deleteProviderLLMModel(modelId); + } else if (modelType === 'embedding') { + await httpClient.deleteProviderEmbeddingModel(modelId); + } else { + await httpClient.deleteProviderRerankModel(modelId); + } + toast.success(t('models.deleteSuccess')); + loadProviderModels(providerUuid, true); + loadProviders(); + } catch (err) { + toast.error(t('models.deleteError') + (err as Error).message); + } + } + + async function handleTestModel( + providerUuid: string, + name: string, + modelType: ModelType, + abilities: string[], + extraArgs: ExtraArg[], + ) { + setIsTesting(true); + setTestResult(null); + const startTime = Date.now(); + try { + const extraArgsObj = convertExtraArgsToObject(extraArgs); + + // Get the provider info + const provider = providers.find((p) => p.uuid === providerUuid); + const providerData = { + requester: provider?.requester || '', + base_url: provider?.base_url || '', + api_keys: provider?.api_keys || [], + }; + + if (modelType === 'llm') { + await httpClient.testLLMModel('_', { + uuid: '', + name, + provider_uuid: '', + provider: providerData, + abilities, + extra_args: extraArgsObj, + } as never); + } else if (modelType === 'embedding') { + await httpClient.testEmbeddingModel('_', { + uuid: '', + name, + provider_uuid: '', + provider: providerData, + extra_args: extraArgsObj, + } as never); + } else { + await httpClient.testRerankModel('_', { + uuid: '', + name, + provider_uuid: '', + provider: providerData, + extra_args: extraArgsObj, + } as never); + } + const duration = Date.now() - startTime; + setTestResult({ success: true, duration }); + } catch (err) { + console.error('Failed to test model', err); + toast.error(t('models.testError') + ': ' + (err as CustomApiError).msg); + setTestResult(null); + } finally { + setIsTesting(false); + } + } + + function handleFormClose() { + setProviderFormOpen(false); + loadProviders(); + // Refresh expanded providers + expandedProviders.forEach((uuid) => loadProviderModels(uuid)); + } + + function renderProviderCard( + provider: ModelProvider, + isLangBotModels: boolean = false, + ) { + return ( + toggleProvider(provider.uuid)} + onEditProvider={() => handleEditProvider(provider.uuid)} + onDeleteProvider={() => handleDeleteProvider(provider.uuid)} + onSpaceLogin={handleSpaceLogin} + onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)} + onCloseAddModel={() => setAddModelPopoverOpen(null)} + onAddModel={(modelType, name, abilities, extraArgs, contextLength) => + handleAddModel( + provider.uuid, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) + } + onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)} + onAddScannedModels={(modelType, models) => + handleAddScannedModels(provider.uuid, modelType, models) + } + onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)} + onCloseEditModel={() => setEditModelPopoverOpen(null)} + onUpdateModel={( + modelId, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) => + handleUpdateModel( + provider.uuid, + modelId, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) + } + onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)} + onCloseDeleteConfirm={() => setDeleteConfirmOpen(null)} + onDeleteModel={(modelId, modelType) => + handleDeleteModel(provider.uuid, modelId, modelType) + } + onTestModel={(name, modelType, abilities, extraArgs) => + handleTestModel(provider.uuid, name, modelType, abilities, extraArgs) + } + isSubmitting={isSubmitting} + isTesting={isTesting} + testResult={testResult} + onResetTestResult={() => setTestResult(null)} + /> + ); + } + + return ( + <> +
+ {/* LangBot Models Card */} + {langbotProvider && renderProviderCard(langbotProvider, true)} + + {/* Add Provider Button */} +
+ + {otherProviders.length === 0 + ? t( + systemInfo.disable_models_service + ? 'models.addProviderHintSimple' + : 'models.addProviderHint', + ) + : t('models.providerCount', { count: otherProviders.length })} + +
+ +
+
+ + {/* Provider List */} + {otherProviders.length === 0 ? ( +
+ +

{t('models.noProviders')}

+
+ ) : ( + otherProviders.map((p) => renderProviderCard(p)) + )} +
+ + + + + + {editingProviderId + ? t('models.editProvider') + : t('models.addProvider')} + + + setProviderFormOpen(false)} + /> + + + + ); +} diff --git a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx new file mode 100644 index 000000000..f22fc779f --- /dev/null +++ b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx @@ -0,0 +1,204 @@ +import { useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { KeyRound, Sparkles, Settings, HardDrive } from 'lucide-react'; +import { + Dialog, + DialogContent, + DialogTitle, + DialogDescription, +} from '@/components/ui/dialog'; +import { + Sidebar, + SidebarContent, + SidebarGroup, + SidebarGroupContent, + SidebarMenu, + SidebarMenuButton, + SidebarMenuItem, + SidebarProvider, +} from '@/components/ui/sidebar'; +import { cn } from '@/lib/utils'; +import AccountSettingsPanel from '@/app/home/components/account-settings-dialog/AccountSettingsPanel'; +import ApiIntegrationPanel from '@/app/home/components/api-integration-dialog/ApiIntegrationPanel'; +import ModelsPanel from '@/app/home/components/models-dialog/ModelsPanel'; +import StorageAnalysisPanel from '@/app/home/components/storage-analysis-dialog/StorageAnalysisPanel'; + +// The set of settings sections shown in the unified dialog. The string values +// are also reused as the ?action= query param suffix so deep links keep working. +export type SettingsSection = + | 'account' + | 'apiIntegration' + | 'models' + | 'storageAnalysis'; + +// Map between a section id and its ?action= query value, so existing deep links +// (showAccountSettings, showApiIntegrationSettings, showModelSettings, +// showStorageAnalysis) continue to resolve to the right section. +export const SETTINGS_ACTION_BY_SECTION: Record = { + account: 'showAccountSettings', + apiIntegration: 'showApiIntegrationSettings', + models: 'showModelSettings', + storageAnalysis: 'showStorageAnalysis', +}; + +export const SETTINGS_SECTION_BY_ACTION: Record = + Object.fromEntries( + Object.entries(SETTINGS_ACTION_BY_SECTION).map(([section, action]) => [ + action, + section as SettingsSection, + ]), + ); + +interface SettingsDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + section: SettingsSection; + onSectionChange: (section: SettingsSection) => void; +} + +export default function SettingsDialog({ + open, + onOpenChange, + section, + onSectionChange, +}: SettingsDialogProps) { + const { t } = useTranslation(); + // A nested modal (e.g. the provider form) can request that we ignore + // outer-close until it is dismissed. + const [blocking, setBlocking] = useState(false); + + // Only the Models panel can raise a blocking nested modal. When we navigate + // away from it (or close the dialog) the panel unmounts without resetting, + // so clear the flag here to avoid getting stuck unable to close. + useEffect(() => { + if (section !== 'models' || !open) { + setBlocking(false); + } + }, [section, open]); + + const navItems: { + id: SettingsSection; + label: string; + icon: React.ReactNode; + }[] = [ + { + id: 'models', + label: t('models.title'), + icon: , + }, + { + id: 'apiIntegration', + label: t('common.apiIntegration'), + icon: , + }, + { + id: 'storageAnalysis', + label: t('storageAnalysis.title'), + icon: , + }, + { + id: 'account', + label: t('account.settings'), + icon: , + }, + ]; + + const activeLabel = + navItems.find((item) => item.id === section)?.label ?? + t('settingsDialog.title'); + + return ( + { + if (!newOpen && blocking) return; + onOpenChange(newOpen); + }} + > + + + {t('settingsDialog.title')} + + {activeLabel} + + + + + + +
+ {t('settingsDialog.title')} +
+ + {navItems.map((item) => ( + + onSectionChange(item.id)} + > + {item.icon} + {item.label} + + + ))} + +
+
+
+
+ +
+ {/* Mobile section switcher (sidebar is hidden on small screens) */} +
+ {navItems.map((item) => ( + + ))} +
+ +
+ {section === 'models' && ( + + )} + {section === 'apiIntegration' && ( + + )} + {section === 'storageAnalysis' && ( + + )} + {section === 'account' && ( + + )} +
+
+
+
+
+ ); +} diff --git a/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisDialog.tsx b/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisDialog.tsx deleted file mode 100644 index 210b93deb..000000000 --- a/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisDialog.tsx +++ /dev/null @@ -1,410 +0,0 @@ -'use client'; - -import { - type ReactNode, - useCallback, - useEffect, - useMemo, - useState, -} from 'react'; -import { useTranslation } from 'react-i18next'; -import { - AlertCircle, - Archive, - Clock, - Database, - FileWarning, - HardDrive, - RefreshCw, -} from 'lucide-react'; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from '@/components/ui/dialog'; -import { Button } from '@/components/ui/button'; -import { Badge } from '@/components/ui/badge'; -import { ScrollArea } from '@/components/ui/scroll-area'; -import { backendClient } from '@/app/infra/http'; - -interface StorageSection { - key: string; - path: string; - exists: boolean; - size_bytes: number; - file_count: number; -} - -interface CleanupCandidate { - key?: string; - name?: string; - size_bytes: number; - modified_at?: string; - date?: string; -} - -interface StorageAnalysis { - generated_at: string; - cleanup_policy: { - uploaded_file_retention_days: number; - log_retention_days: number; - }; - sections: StorageSection[]; - database: { - type: string; - monitoring_counts: Record; - binary_storage: { - count: number; - size_bytes: number | null; - }; - }; - cleanup_candidates: { - uploaded_files: CleanupCandidate[]; - log_files: CleanupCandidate[]; - }; - tasks: Record; -} - -interface StorageAnalysisDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; -} - -function formatBytes(bytes: number | null | undefined): string { - if (bytes === null || bytes === undefined) { - return '-'; - } - if (bytes < 1024) { - return `${bytes} B`; - } - const units = ['KB', 'MB', 'GB', 'TB']; - let value = bytes / 1024; - let unitIndex = 0; - while (value >= 1024 && unitIndex < units.length - 1) { - value /= 1024; - unitIndex += 1; - } - return `${value.toFixed(value >= 10 ? 1 : 2)} ${units[unitIndex]}`; -} - -export default function StorageAnalysisDialog({ - open, - onOpenChange, -}: StorageAnalysisDialogProps) { - const { t } = useTranslation(); - const [analysis, setAnalysis] = useState(null); - const [loading, setLoading] = useState(false); - const [error, setError] = useState(null); - - const loadAnalysis = useCallback(async () => { - setLoading(true); - setError(null); - try { - const result = await backendClient.get( - '/api/v1/system/storage-analysis', - ); - setAnalysis(result); - } catch (err) { - setError(err instanceof Error ? err.message : String(err)); - } finally { - setLoading(false); - } - }, []); - - useEffect(() => { - if (open) { - loadAnalysis(); - } - }, [loadAnalysis, open]); - - const totalBytes = useMemo(() => { - return ( - analysis?.sections.reduce((sum, item) => sum + item.size_bytes, 0) ?? 0 - ); - }, [analysis]); - - const uploadedCandidateBytes = useMemo(() => { - return ( - analysis?.cleanup_candidates.uploaded_files.reduce( - (sum, item) => sum + item.size_bytes, - 0, - ) ?? 0 - ); - }, [analysis]); - - const logCandidateBytes = useMemo(() => { - return ( - analysis?.cleanup_candidates.log_files.reduce( - (sum, item) => sum + item.size_bytes, - 0, - ) ?? 0 - ); - }, [analysis]); - - return ( - - - - - - {t('storageAnalysis.dialogTitle')} - - - {t('storageAnalysis.description')} - - - -
-
- {analysis - ? t('storageAnalysis.generatedAt', { - time: new Date(analysis.generated_at).toLocaleString(), - }) - : t('storageAnalysis.loading')} -
- -
- - -
- {error && ( -
- - {error} -
- )} - - {analysis && ( - <> -
- } - /> - } - /> - } - /> - } - /> -
- -
-

- - {t('storageAnalysis.cleanupPolicy')} -

-
- - - -
-
- -
-

- {t('storageAnalysis.sections')} -

-
- {analysis.sections.map((section) => ( -
-
-
- {t(`storageAnalysis.sectionNames.${section.key}`)} -
-
- {section.path || '-'} -
-
- {section.exists ? ( - - ) : ( - - {t('storageAnalysis.missing')} - - )} -
- {formatBytes(section.size_bytes)} -
-
- {section.file_count} -
-
- ))} -
-
- -
- - -
- -
- - -
- - )} -
-
-
-
- ); -} - -function SummaryItem({ - label, - value, - icon, - meta, -}: { - label: string; - value: string; - icon: ReactNode; - meta?: string; -}) { - return ( -
-
- {icon} - {label} -
-
- {value} - {meta && {meta}} -
-
- ); -} - -function PolicyItem({ label, value }: { label: string; value: string }) { - return ( -
-
{label}
-
{value}
-
- ); -} - -function MetricPanel({ - title, - values, -}: { - title: string; - values: Record; -}) { - return ( -
-

{title}

-
- {Object.entries(values).map(([key, value]) => ( -
- {key} - {value ?? '-'} -
- ))} -
-
- ); -} - -function CandidatePanel({ - title, - emptyText, - candidates, -}: { - title: string; - emptyText: string; - candidates: CleanupCandidate[]; -}) { - return ( -
-

- - {title} -

-
- {candidates.length === 0 ? ( -
- {emptyText} -
- ) : ( - candidates.slice(0, 8).map((candidate, index) => ( -
-
-
- {candidate.key ?? candidate.name} -
-
- {candidate.modified_at ?? candidate.date ?? '-'} -
-
-
- {formatBytes(candidate.size_bytes)} -
-
- )) - )} -
-
- ); -} diff --git a/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx b/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx new file mode 100644 index 000000000..0488616c1 --- /dev/null +++ b/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx @@ -0,0 +1,390 @@ +'use client'; + +import { + type ReactNode, + useCallback, + useEffect, + useMemo, + useState, +} from 'react'; +import { useTranslation } from 'react-i18next'; +import { + AlertCircle, + Archive, + Clock, + Database, + FileWarning, + HardDrive, + RefreshCw, +} from 'lucide-react'; +import { Button } from '@/components/ui/button'; +import { Badge } from '@/components/ui/badge'; +import { ScrollArea } from '@/components/ui/scroll-area'; +import { backendClient } from '@/app/infra/http'; + +interface StorageSection { + key: string; + path: string; + exists: boolean; + size_bytes: number; + file_count: number; +} + +interface CleanupCandidate { + key?: string; + name?: string; + size_bytes: number; + modified_at?: string; + date?: string; +} + +interface StorageAnalysis { + generated_at: string; + cleanup_policy: { + uploaded_file_retention_days: number; + log_retention_days: number; + }; + sections: StorageSection[]; + database: { + type: string; + monitoring_counts: Record; + binary_storage: { + count: number; + size_bytes: number | null; + }; + }; + cleanup_candidates: { + uploaded_files: CleanupCandidate[]; + log_files: CleanupCandidate[]; + }; + tasks: Record; +} + +interface StorageAnalysisPanelProps { + // True when this panel is the active section and the dialog is open. + active: boolean; +} + +function formatBytes(bytes: number | null | undefined): string { + if (bytes === null || bytes === undefined) { + return '-'; + } + if (bytes < 1024) { + return `${bytes} B`; + } + const units = ['KB', 'MB', 'GB', 'TB']; + let value = bytes / 1024; + let unitIndex = 0; + while (value >= 1024 && unitIndex < units.length - 1) { + value /= 1024; + unitIndex += 1; + } + return `${value.toFixed(value >= 10 ? 1 : 2)} ${units[unitIndex]}`; +} + +export default function StorageAnalysisPanel({ + active, +}: StorageAnalysisPanelProps) { + const { t } = useTranslation(); + const [analysis, setAnalysis] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + + const loadAnalysis = useCallback(async () => { + setLoading(true); + setError(null); + try { + const result = await backendClient.get( + '/api/v1/system/storage-analysis', + ); + setAnalysis(result); + } catch (err) { + setError(err instanceof Error ? err.message : String(err)); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + if (active) { + loadAnalysis(); + } + }, [loadAnalysis, active]); + + const totalBytes = useMemo(() => { + return ( + analysis?.sections.reduce((sum, item) => sum + item.size_bytes, 0) ?? 0 + ); + }, [analysis]); + + const uploadedCandidateBytes = useMemo(() => { + return ( + analysis?.cleanup_candidates.uploaded_files.reduce( + (sum, item) => sum + item.size_bytes, + 0, + ) ?? 0 + ); + }, [analysis]); + + const logCandidateBytes = useMemo(() => { + return ( + analysis?.cleanup_candidates.log_files.reduce( + (sum, item) => sum + item.size_bytes, + 0, + ) ?? 0 + ); + }, [analysis]); + + return ( +
+
+
+ {analysis + ? t('storageAnalysis.generatedAt', { + time: new Date(analysis.generated_at).toLocaleString(), + }) + : t('storageAnalysis.loading')} +
+ +
+ + +
+ {error && ( +
+ + {error} +
+ )} + + {analysis && ( + <> +
+ } + /> + } + /> + } + /> + } + /> +
+ +
+

+ + {t('storageAnalysis.cleanupPolicy')} +

+
+ + + +
+
+ +
+

+ {t('storageAnalysis.sections')} +

+
+ {analysis.sections.map((section) => ( +
+
+
+ {t(`storageAnalysis.sectionNames.${section.key}`)} +
+
+ {section.path || '-'} +
+
+ {section.exists ? ( + + ) : ( + + {t('storageAnalysis.missing')} + + )} +
+ {formatBytes(section.size_bytes)} +
+
+ {section.file_count} +
+
+ ))} +
+
+ +
+ + +
+ +
+ + +
+ + )} +
+
+
+ ); +} + +function SummaryItem({ + label, + value, + icon, + meta, +}: { + label: string; + value: string; + icon: ReactNode; + meta?: string; +}) { + return ( +
+
+ {icon} + {label} +
+
+ {value} + {meta && {meta}} +
+
+ ); +} + +function PolicyItem({ label, value }: { label: string; value: string }) { + return ( +
+
{label}
+
{value}
+
+ ); +} + +function MetricPanel({ + title, + values, +}: { + title: string; + values: Record; +}) { + return ( +
+

{title}

+
+ {Object.entries(values).map(([key, value]) => ( +
+ {key} + {value ?? '-'} +
+ ))} +
+
+ ); +} + +function CandidatePanel({ + title, + emptyText, + candidates, +}: { + title: string; + emptyText: string; + candidates: CleanupCandidate[]; +}) { + return ( +
+

+ + {title} +

+
+ {candidates.length === 0 ? ( +
+ {emptyText} +
+ ) : ( + candidates.slice(0, 8).map((candidate, index) => ( +
+
+
+ {candidate.key ?? candidate.name} +
+
+ {candidate.modified_at ?? candidate.date ?? '-'} +
+
+
+ {formatBytes(candidate.size_bytes)} +
+
+ )) + )} +
+
+ ); +} diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 91f0b6f31..f65a477f6 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -1386,6 +1386,9 @@ const enUS = { boxSessionCreated: 'Created', boxSessionLastUsed: 'Last used', }, + settingsDialog: { + title: 'Settings', + }, storageAnalysis: { title: 'Storage Analysis', description: 'Inspect storage usage and cleanup candidates', diff --git a/web/src/i18n/locales/es-ES.ts b/web/src/i18n/locales/es-ES.ts index 5f8028154..564563f5d 100644 --- a/web/src/i18n/locales/es-ES.ts +++ b/web/src/i18n/locales/es-ES.ts @@ -1419,6 +1419,9 @@ const esES = { boxSessionCreated: 'Creado', boxSessionLastUsed: 'Último uso', }, + settingsDialog: { + title: 'Configuración', + }, storageAnalysis: { title: 'Análisis de almacenamiento', description: diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index c8c651524..9769fa5cd 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -1392,6 +1392,9 @@ const jaJP = { boxSessionCreated: '作成日時', boxSessionLastUsed: '最終使用', }, + settingsDialog: { + title: '設定', + }, storageAnalysis: { title: 'ストレージ分析', description: 'ストレージ使用量とクリーンアップ候補を確認します', diff --git a/web/src/i18n/locales/ru-RU.ts b/web/src/i18n/locales/ru-RU.ts index 8ebb4fa2b..f2829f775 100644 --- a/web/src/i18n/locales/ru-RU.ts +++ b/web/src/i18n/locales/ru-RU.ts @@ -1395,6 +1395,9 @@ const ruRU = { boxSessionCreated: 'Создано', boxSessionLastUsed: 'Последнее использование', }, + settingsDialog: { + title: 'Настройки', + }, storageAnalysis: { title: 'Анализ хранилища', description: 'Проверьте использование хранилища и кандидатов на очистку', diff --git a/web/src/i18n/locales/th-TH.ts b/web/src/i18n/locales/th-TH.ts index ac976402d..c14bdad02 100644 --- a/web/src/i18n/locales/th-TH.ts +++ b/web/src/i18n/locales/th-TH.ts @@ -1364,6 +1364,9 @@ const thTH = { boxSessionCreated: 'สร้างเมื่อ', boxSessionLastUsed: 'ใช้ล่าสุด', }, + settingsDialog: { + title: 'การตั้งค่า', + }, storageAnalysis: { title: 'วิเคราะห์พื้นที่จัดเก็บ', description: 'ตรวจสอบการใช้พื้นที่จัดเก็บและรายการที่สามารถล้างได้', diff --git a/web/src/i18n/locales/vi-VN.ts b/web/src/i18n/locales/vi-VN.ts index be1e7754b..495d4d146 100644 --- a/web/src/i18n/locales/vi-VN.ts +++ b/web/src/i18n/locales/vi-VN.ts @@ -1388,6 +1388,9 @@ const viVN = { boxSessionCreated: 'Đã tạo', boxSessionLastUsed: 'Lần cuối sử dụng', }, + settingsDialog: { + title: 'Cài đặt', + }, storageAnalysis: { title: 'Phân tích lưu trữ', description: 'Kiểm tra dung lượng lưu trữ và các mục có thể dọn dẹp', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index f32f039aa..590578e76 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -1328,6 +1328,9 @@ const zhHans = { boxSessionCreated: '创建时间', boxSessionLastUsed: '最后使用', }, + settingsDialog: { + title: '设置', + }, storageAnalysis: { title: '存储分析', description: '查看存储占用和可清理文件', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 539b34c4f..4a8b31383 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -1327,6 +1327,9 @@ const zhHant = { boxSessionCreated: '建立時間', boxSessionLastUsed: '最後使用', }, + settingsDialog: { + title: '設定', + }, storageAnalysis: { title: '儲存分析', description: '查看儲存占用和可清理檔案', From b3c00fe6da5fba4293bb4ea87b60e056a3dc24fc Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:18:14 -0400 Subject: [PATCH 04/16] fix(web): use fixed height for settings dialog instead of 80vh Avoid the dialog stretching to fill tall viewports (large empty space). Pin to 620px with max-h-[85vh] fallback and narrow width to 52rem. --- .../app/home/components/settings-dialog/SettingsDialog.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx index f22fc779f..9b24b9e21 100644 --- a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx +++ b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx @@ -116,7 +116,7 @@ export default function SettingsDialog({ }} > @@ -128,7 +128,7 @@ export default function SettingsDialog({ @@ -154,7 +154,7 @@ export default function SettingsDialog({ -
+
{/* Mobile section switcher (sidebar is hidden on small screens) */}
{navItems.map((item) => ( From 716d7aca947e02b8eb3831bea0fd5798550aeea8 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:22:42 -0400 Subject: [PATCH 05/16] fix(web): fixed-height settings dialog, narrower sidebar Pin the dialog to a fixed 80vh (cap 800px) so switching sections no longer resizes it; panels scroll their own content internally. Override the SidebarProvider wrapper's default h-svh with h-full so both columns fill the dialog height. Narrow the inner settings sidebar to w-44. --- .../components/settings-dialog/SettingsDialog.tsx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx index 9b24b9e21..52b799c37 100644 --- a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx +++ b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx @@ -116,19 +116,21 @@ export default function SettingsDialog({ }} > {t('settingsDialog.title')} {activeLabel} - + {/* Override the SidebarProvider wrapper's default h-svh so the two + columns fill the dialog's fixed height instead of the viewport. */} + @@ -154,7 +156,7 @@ export default function SettingsDialog({ -
+
{/* Mobile section switcher (sidebar is hidden on small screens) */}
{navItems.map((item) => ( From d4699547e9aae84b524b4f4831b89705184580cc Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:27:10 -0400 Subject: [PATCH 06/16] i18n(web): localize Bots/Pipelines sidebar titles for es/th/vi MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit es-ES pipelines, th-TH bots+pipelines and vi-VN pipelines were left in English in the sidebar. Translate them: es Flujos, th บอท/ไปป์ไลน์, vi Quy trình. --- web/src/i18n/locales/es-ES.ts | 2 +- web/src/i18n/locales/th-TH.ts | 4 ++-- web/src/i18n/locales/vi-VN.ts | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/web/src/i18n/locales/es-ES.ts b/web/src/i18n/locales/es-ES.ts index 564563f5d..50b2dc2e3 100644 --- a/web/src/i18n/locales/es-ES.ts +++ b/web/src/i18n/locales/es-ES.ts @@ -844,7 +844,7 @@ const esES = { 'Una vez eliminada, la configuración de este servidor MCP no se podrá recuperar.', }, pipelines: { - title: 'Pipelines', + title: 'Flujos', description: 'Los Pipelines definen el flujo de procesamiento de eventos de mensajes, se usan para vincular a los Bots', createPipeline: 'Crear Pipeline', diff --git a/web/src/i18n/locales/th-TH.ts b/web/src/i18n/locales/th-TH.ts index c14bdad02..924e99c65 100644 --- a/web/src/i18n/locales/th-TH.ts +++ b/web/src/i18n/locales/th-TH.ts @@ -300,7 +300,7 @@ const thTH = { }, }, bots: { - title: 'Bot', + title: 'บอท', description: 'สร้างและจัดการ Bot ซึ่งเป็นจุดเชื่อมต่อของ LangBot กับแพลตฟอร์มต่างๆ', createBot: 'สร้าง Bot', @@ -819,7 +819,7 @@ const thTH = { 'เมื่อลบแล้ว การกำหนดค่าเซิร์ฟเวอร์ MCP นี้จะไม่สามารถกู้คืนได้', }, pipelines: { - title: 'Pipeline', + title: 'ไปป์ไลน์', description: 'Pipeline กำหนดกระบวนการประมวลผลเหตุการณ์ข้อความ ใช้เพื่อผูกกับ Bot', createPipeline: 'สร้าง Pipeline', diff --git a/web/src/i18n/locales/vi-VN.ts b/web/src/i18n/locales/vi-VN.ts index 495d4d146..55bb82c35 100644 --- a/web/src/i18n/locales/vi-VN.ts +++ b/web/src/i18n/locales/vi-VN.ts @@ -833,7 +833,7 @@ const viVN = { deleteMCPHint: 'Sau khi xóa, cấu hình máy chủ MCP này không thể khôi phục.', }, pipelines: { - title: 'Pipeline', + title: 'Quy trình', description: 'Pipeline xác định luồng xử lý sự kiện tin nhắn, dùng để liên kết với Bot', createPipeline: 'Tạo Pipeline', From 2d6faf9d5ed65a0349038279c1b2e14a0ba7e3a9 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:41:58 -0400 Subject: [PATCH 07/16] refactor(web): drop legacy ModelsDialog, use unified SettingsDialog everywhere The model-selector in dynamic forms (pipeline / knowledge base settings) still opened the old standalone ModelsDialog. Point it at the unified SettingsDialog (section pinned to models) and delete the now-unused ModelsDialog wrapper so only the new dialog remains. --- .../dynamic-form/DynamicFormItemComponent.tsx | 14 +++++-- .../components/models-dialog/ModelsDialog.tsx | 42 ------------------- 2 files changed, 11 insertions(+), 45 deletions(-) delete mode 100644 web/src/app/home/components/models-dialog/ModelsDialog.tsx diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx index 04979d75f..ce125c70b 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx @@ -61,7 +61,9 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu'; -import ModelsDialog from '@/app/home/components/models-dialog/ModelsDialog'; +import SettingsDialog, { + SettingsSection, +} from '@/app/home/components/settings-dialog/SettingsDialog'; export default function DynamicFormItemComponent({ config, @@ -87,6 +89,8 @@ export default function DynamicFormItemComponent({ ); const { t } = useTranslation(); const [modelsDialogOpen, setModelsDialogOpen] = useState(false); + const [settingsSection, setSettingsSection] = + useState('models'); const fetchLlmModels = () => { httpClient @@ -561,9 +565,11 @@ export default function DynamicFormItemComponent({ {t('models.title')} -
); @@ -913,9 +919,11 @@ export default function DynamicFormItemComponent({ {t('models.title')} -
diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx deleted file mode 100644 index dc27f3174..000000000 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ /dev/null @@ -1,42 +0,0 @@ -import { useState } from 'react'; -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, -} from '@/components/ui/dialog'; -import { useTranslation } from 'react-i18next'; -import ModelsPanel from './ModelsPanel'; - -interface ModelsDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; -} - -// Standalone Models dialog. The unified Settings dialog renders -// directly; this wrapper is kept for places that open Models on its own -// (e.g. the model picker inside dynamic forms). -export default function ModelsDialog({ - open, - onOpenChange, -}: ModelsDialogProps) { - const { t } = useTranslation(); - const [blocking, setBlocking] = useState(false); - - return ( - { - if (!newOpen && blocking) return; - onOpenChange(newOpen); - }} - > - - - {t('models.title')} - - - - - ); -} From e9db858dcc2d4a5a46bb3c03b03050420c160cb9 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 05:50:44 -0400 Subject: [PATCH 08/16] feat(web): unified header for settings dialog, shorter sidebar labels - Add a shared section header (icon + title + description) with right padding so the dialog close X no longer overlaps panel content, and every tab now shares the same top layout for a consistent look. - Shorten inner sidebar nav labels (Models/API/Storage/Account) via new settingsDialog.nav.* i18n keys across all 8 locales. - Add common.apiIntegrationDescription and account.settingsDescription for the new header. --- .../settings-dialog/SettingsDialog.tsx | 37 +++++++++++++++---- web/src/i18n/locales/en-US.ts | 9 +++++ web/src/i18n/locales/es-ES.ts | 9 +++++ web/src/i18n/locales/ja-JP.ts | 9 +++++ web/src/i18n/locales/ru-RU.ts | 9 +++++ web/src/i18n/locales/th-TH.ts | 9 +++++ web/src/i18n/locales/vi-VN.ts | 9 +++++ web/src/i18n/locales/zh-Hans.ts | 8 ++++ web/src/i18n/locales/zh-Hant.ts | 8 ++++ 9 files changed, 100 insertions(+), 7 deletions(-) diff --git a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx index 52b799c37..8984b6439 100644 --- a/web/src/app/home/components/settings-dialog/SettingsDialog.tsx +++ b/web/src/app/home/components/settings-dialog/SettingsDialog.tsx @@ -79,33 +79,42 @@ export default function SettingsDialog({ const navItems: { id: SettingsSection; label: string; + title: string; + description: string; icon: React.ReactNode; }[] = [ { id: 'models', - label: t('models.title'), + label: t('settingsDialog.nav.models'), + title: t('models.title'), + description: t('models.description'), icon: , }, { id: 'apiIntegration', - label: t('common.apiIntegration'), + label: t('settingsDialog.nav.api'), + title: t('common.apiIntegration'), + description: t('common.apiIntegrationDescription'), icon: , }, { id: 'storageAnalysis', - label: t('storageAnalysis.title'), + label: t('settingsDialog.nav.storage'), + title: t('storageAnalysis.title'), + description: t('storageAnalysis.description'), icon: , }, { id: 'account', - label: t('account.settings'), + label: t('settingsDialog.nav.account'), + title: t('account.settings'), + description: t('account.settingsDescription'), icon: , }, ]; - const activeLabel = - navItems.find((item) => item.id === section)?.label ?? - t('settingsDialog.title'); + const activeItem = navItems.find((item) => item.id === section); + const activeLabel = activeItem?.title ?? t('settingsDialog.title'); return ( + {/* Unified section header (shared across all tabs). The extra + right padding keeps the title clear of the dialog's close X. */} +
+

+ {activeItem?.icon} + {activeItem?.title} +

+ {activeItem?.description && ( +

+ {activeItem.description} +

+ )} +
+
{section === 'models' && ( Date: Tue, 16 Jun 2026 06:02:20 -0400 Subject: [PATCH 09/16] refactor(web): unify settings panel layouts with shared toolbar/body - Add PanelToolbar/PanelBody primitives so all four settings tabs share the same top-toolbar + scrollable-body rhythm under the unified header. - API panel: drop the heavy gray shadowed TabsList; move the create action into the toolbar next to the tabs, lighten per-tab hints. - Storage panel: reuse PanelToolbar for the generated-at/refresh bar. - Account panel: wrap content in PanelBody for consistent padding. - Models panel: keep the pinned LangBot Models (Space) card at the very top, above the add-custom-provider row (intentional pin), using PanelBody instead of a top toolbar. --- .../bot-form/RoutingRulesEditor.tsx | 1 - .../AccountSettingsPanel.tsx | 5 +- .../ApiIntegrationPanel.tsx | 390 +++++++++--------- .../dynamic-form/DynamicFormItemConfig.ts | 3 +- .../components/models-dialog/ModelsPanel.tsx | 22 +- .../settings-dialog/panel-layout.tsx | 45 ++ .../StorageAnalysisPanel.tsx | 5 +- .../overview-cards/SystemStatusCards.tsx | 1 - .../pipeline-form/PipelineFormComponent.tsx | 5 +- .../plugin-installed/PluginCardVO.ts | 4 +- web/src/app/infra/entities/common.ts | 2 +- web/src/app/infra/entities/form/dynamic.ts | 2 +- web/src/app/infra/entities/plugin/index.ts | 2 +- web/src/app/wizard/page.tsx | 2 +- 14 files changed, 262 insertions(+), 227 deletions(-) create mode 100644 web/src/app/home/components/settings-dialog/panel-layout.tsx diff --git a/web/src/app/home/bots/components/bot-form/RoutingRulesEditor.tsx b/web/src/app/home/bots/components/bot-form/RoutingRulesEditor.tsx index 42b866261..f8c7efbdf 100644 --- a/web/src/app/home/bots/components/bot-form/RoutingRulesEditor.tsx +++ b/web/src/app/home/bots/components/bot-form/RoutingRulesEditor.tsx @@ -48,7 +48,6 @@ interface PipelineOption { } interface RoutingRulesEditorProps { - // eslint-disable-next-line @typescript-eslint/no-explicit-any form: UseFormReturn; pipelineNameList: PipelineOption[]; } diff --git a/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx b/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx index 5cf7e4c8b..0795f413e 100644 --- a/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx +++ b/web/src/app/home/components/account-settings-dialog/AccountSettingsPanel.tsx @@ -14,6 +14,7 @@ import { httpClient } from '@/app/infra/http/HttpClient'; import { systemInfo } from '@/app/infra/http'; import { Loader2, ExternalLink, KeyRound, Layers } from 'lucide-react'; import PasswordChangeDialog from '../password-change-dialog/PasswordChangeDialog'; +import { PanelBody } from '../settings-dialog/panel-layout'; interface AccountSettingsPanelProps { // True when this panel is the active section and the dialog is open. @@ -86,7 +87,7 @@ export default function AccountSettingsPanel({ }; return ( -
+ {userEmail && (

{userEmail}

)} @@ -165,6 +166,6 @@ export default function AccountSettingsPanel({ onOpenChange={handlePasswordDialogClose} hasPassword={hasPassword} /> -
+ ); } diff --git a/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx b/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx index e45d5f501..9cb14a696 100644 --- a/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx +++ b/web/src/app/home/components/api-integration-dialog/ApiIntegrationPanel.tsx @@ -36,6 +36,7 @@ import { } from '@/components/ui/alert-dialog'; import * as AlertDialogPrimitive from '@radix-ui/react-alert-dialog'; import { backendClient } from '@/app/infra/http'; +import { PanelToolbar } from '../settings-dialog/panel-layout'; interface ApiKey { id: number; @@ -252,216 +253,209 @@ export default function ApiIntegrationPanel({ return ( <> -
- - - - {t('common.apiKeys')} - - - {t('common.webhooks')} - + + + + {t('common.apiKeys')} + {t('common.webhooks')} + {activeTab === 'apikeys' ? ( + + ) : ( + + )} + - {/* API Keys Tab */} - -
- {t('common.apiKeyHint')} + {/* API Keys Tab */} + +

+ {t('common.apiKeyHint')} +

+ + {loading ? ( +
+ {t('common.loading')}
- -
- + ) : apiKeys.length === 0 ? ( +
+ {t('common.noApiKeys')}
- - {loading ? ( -
- {t('common.loading')} -
- ) : apiKeys.length === 0 ? ( -
- {t('common.noApiKeys')} -
- ) : ( -
- - - - - {t('common.name')} - - - {t('common.apiKeyValue')} - - - {t('common.actions')} - - - - - {apiKeys.map((item) => ( - - -
-
{item.name}
- {item.description && ( -
- {item.description} -
- )} -
-
- - - {maskApiKey(item.key)} - - - -
- - -
-
-
- ))} -
-
-
- )} - - - {/* Webhooks Tab */} - -
- {t('common.webhookHint')} -
- -
- -
- - {loading ? ( -
- {t('common.loading')} -
- ) : webhooks.length === 0 ? ( -
- {t('common.noWebhooks')} -
- ) : ( -
- - - - - {t('common.name')} - - - {t('common.webhookUrl')} - - - {t('common.webhookEnabled')} - - - {t('common.actions')} - - - - - {webhooks.map((webhook) => ( - - -
-
- {webhook.name} + ) : ( +
+
+ + + + {t('common.name')} + + + {t('common.apiKeyValue')} + + + {t('common.actions')} + + + + + {apiKeys.map((item) => ( + + +
+
{item.name}
+ {item.description && ( +
+ {item.description}
- {webhook.description && ( -
- {webhook.description} -
- )} -
-
- -
- - {webhook.url} - -
-
- - handleToggleWebhook(webhook)} - /> - - + )} + + + + + {maskApiKey(item.key)} + + + +
+ - - - ))} - -
-
- )} -
- -
+
+ + + ))} + + +
+ )} + + + {/* Webhooks Tab */} + +

+ {t('common.webhookHint')} +

+ + {loading ? ( +
+ {t('common.loading')} +
+ ) : webhooks.length === 0 ? ( +
+ {t('common.noWebhooks')} +
+ ) : ( +
+ + + + + {t('common.name')} + + + {t('common.webhookUrl')} + + + {t('common.webhookEnabled')} + + + {t('common.actions')} + + + + + {webhooks.map((webhook) => ( + + +
+
+ {webhook.name} +
+ {webhook.description && ( +
+ {webhook.description} +
+ )} +
+
+ +
+ + {webhook.url} + +
+
+ + handleToggleWebhook(webhook)} + /> + + + + +
+ ))} +
+
+
+ )} +
+ {/* Create API Key Dialog */} diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts b/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts index b11e09d23..62f1dbf0a 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemConfig.ts @@ -47,7 +47,6 @@ export function parseDynamicFormItemType(value: string): DynamicFormItemType { export function getDefaultValues( itemConfigList: IDynamicFormItemSchema[], - // eslint-disable-next-line @typescript-eslint/no-explicit-any ): Record { return itemConfigList.reduce( (acc, item) => { @@ -59,7 +58,7 @@ export function getDefaultValues( acc[item.name] = item.default; return acc; }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any + {} as Record, ); } diff --git a/web/src/app/home/components/models-dialog/ModelsPanel.tsx b/web/src/app/home/components/models-dialog/ModelsPanel.tsx index 7dd32c135..a71bf758d 100644 --- a/web/src/app/home/components/models-dialog/ModelsPanel.tsx +++ b/web/src/app/home/components/models-dialog/ModelsPanel.tsx @@ -23,6 +23,7 @@ import { LANGBOT_MODELS_PROVIDER_REQUESTER, } from './types'; import { CustomApiError } from '@/app/infra/entities/common'; +import { PanelBody } from '../settings-dialog/panel-layout'; interface ModelsPanelProps { // True when this panel is the active section and the dialog is open. @@ -611,12 +612,13 @@ export default function ModelsPanel({ return ( <> -
- {/* LangBot Models Card */} + + {/* LangBot Models (Space) provider card is intentionally pinned to the + top, above the "add custom provider" action row. */} {langbotProvider && renderProviderCard(langbotProvider, true)} - {/* Add Provider Button */} -
+ {/* Add-provider row: stays below the pinned card by design. */} +
{otherProviders.length === 0 ? t( @@ -626,12 +628,10 @@ export default function ModelsPanel({ ) : t('models.providerCount', { count: otherProviders.length })} -
- -
+
{/* Provider List */} @@ -643,7 +643,7 @@ export default function ModelsPanel({ ) : ( otherProviders.map((p) => renderProviderCard(p)) )} -
+
diff --git a/web/src/app/home/components/settings-dialog/panel-layout.tsx b/web/src/app/home/components/settings-dialog/panel-layout.tsx new file mode 100644 index 000000000..27ad7a3c0 --- /dev/null +++ b/web/src/app/home/components/settings-dialog/panel-layout.tsx @@ -0,0 +1,45 @@ +import * as React from 'react'; +import { cn } from '@/lib/utils'; + +/** + * Shared layout primitives for the settings-dialog panels. + * + * Every section renders under the dialog's unified header, so the panels + * themselves should share the same vertical rhythm: an optional top toolbar + * (meta on the left, primary action on the right) followed by a scrollable + * body with consistent padding. Keeping these in one place is what makes the + * tabs feel like one cohesive surface instead of four separately-styled views. + */ + +export function PanelToolbar({ + className, + children, +}: { + className?: string; + children: React.ReactNode; +}) { + return ( +
+ {children} +
+ ); +} + +export function PanelBody({ + className, + children, +}: { + className?: string; + children: React.ReactNode; +}) { + return ( +
+ {children} +
+ ); +} diff --git a/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx b/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx index 0488616c1..833f5e853 100644 --- a/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx +++ b/web/src/app/home/components/storage-analysis-dialog/StorageAnalysisPanel.tsx @@ -21,6 +21,7 @@ import { Button } from '@/components/ui/button'; import { Badge } from '@/components/ui/badge'; import { ScrollArea } from '@/components/ui/scroll-area'; import { backendClient } from '@/app/infra/http'; +import { PanelToolbar } from '../settings-dialog/panel-layout'; interface StorageSection { key: string; @@ -137,7 +138,7 @@ export default function StorageAnalysisPanel({ return (
-
+
{analysis ? t('storageAnalysis.generatedAt', { @@ -156,7 +157,7 @@ export default function StorageAnalysisPanel({ /> {t('storageAnalysis.refresh')} -
+
diff --git a/web/src/app/home/monitoring/components/overview-cards/SystemStatusCards.tsx b/web/src/app/home/monitoring/components/overview-cards/SystemStatusCards.tsx index 8b2f65ea8..46be50a9e 100644 --- a/web/src/app/home/monitoring/components/overview-cards/SystemStatusCards.tsx +++ b/web/src/app/home/monitoring/components/overview-cards/SystemStatusCards.tsx @@ -82,7 +82,6 @@ export default function SystemStatusCard({ fetchStatus(); const interval = setInterval(fetchStatus, 30_000); return () => clearInterval(interval); - // eslint-disable-next-line react-hooks/exhaustive-deps }, [fetchStatus, refreshKey]); const pluginOk = pluginStatus diff --git a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx index 863c22022..5df4baa88 100644 --- a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx +++ b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx @@ -323,7 +323,6 @@ export default function PipelineFormComponent({ const isFirstEmission = !initializedStagesRef.current.has(stageKey); const currentValues = - // eslint-disable-next-line @typescript-eslint/no-explicit-any (form.getValues(formName) as Record) || {}; form.setValue(formName, { ...currentValues, @@ -368,7 +367,6 @@ export default function PipelineFormComponent({ )?.[stage.name] || {} } @@ -402,7 +400,6 @@ export default function PipelineFormComponent({ )?.[stage.name] || {} } @@ -445,7 +442,7 @@ export default function PipelineFormComponent({ // make the locked selector display a scope that is NOT the one actually in // effect. Coerce the displayed/saved value to the forced template so the UI // truthfully reflects runtime behavior. - // eslint-disable-next-line @typescript-eslint/no-explicit-any + const stageInitialValues: Record = (form.watch(formName) as Record)?.[stage.name] || {}; const effectiveInitialValues = diff --git a/web/src/app/home/plugins/components/plugin-installed/PluginCardVO.ts b/web/src/app/home/plugins/components/plugin-installed/PluginCardVO.ts index 279161b43..c1546886d 100644 --- a/web/src/app/home/plugins/components/plugin-installed/PluginCardVO.ts +++ b/web/src/app/home/plugins/components/plugin-installed/PluginCardVO.ts @@ -9,7 +9,7 @@ export interface IPluginCardVO { enabled: boolean; priority: number; install_source: string; - install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any + install_info: Record; status: string; components: PluginComponent[]; debug: boolean; @@ -27,7 +27,7 @@ export class PluginCardVO implements IPluginCardVO { priority: number; debug: boolean; install_source: string; - install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any + install_info: Record; status: string; components: PluginComponent[]; hasUpdate?: boolean; diff --git a/web/src/app/infra/entities/common.ts b/web/src/app/infra/entities/common.ts index 4f04c2ef5..942fc5450 100644 --- a/web/src/app/infra/entities/common.ts +++ b/web/src/app/infra/entities/common.ts @@ -21,7 +21,7 @@ export interface ComponentManifest { version?: string; author?: string; }; - spec: Record; // eslint-disable-line @typescript-eslint/no-explicit-any + spec: Record; } export interface CustomApiError { diff --git a/web/src/app/infra/entities/form/dynamic.ts b/web/src/app/infra/entities/form/dynamic.ts index e2dca5c32..44fed3acf 100644 --- a/web/src/app/infra/entities/form/dynamic.ts +++ b/web/src/app/infra/entities/form/dynamic.ts @@ -8,7 +8,7 @@ export const SYSTEM_FIELD_PREFIX = '__system.'; export interface IShowIfCondition { field: string; operator: 'eq' | 'neq' | 'in'; - // eslint-disable-next-line @typescript-eslint/no-explicit-any + value: any; } diff --git a/web/src/app/infra/entities/plugin/index.ts b/web/src/app/infra/entities/plugin/index.ts index ad661211f..9431fae59 100644 --- a/web/src/app/infra/entities/plugin/index.ts +++ b/web/src/app/infra/entities/plugin/index.ts @@ -10,7 +10,7 @@ export interface Plugin { debug: boolean; enabled: boolean; install_source: string; - install_info: Record; // eslint-disable-line @typescript-eslint/no-explicit-any + install_info: Record; components: PluginComponent[]; } diff --git a/web/src/app/wizard/page.tsx b/web/src/app/wizard/page.tsx index 1fd393750..a3afd07c4 100644 --- a/web/src/app/wizard/page.tsx +++ b/web/src/app/wizard/page.tsx @@ -86,7 +86,7 @@ export default function WizardPage() { const [selectedAdapter, setSelectedAdapter] = useState(null); const [selectedRunner, setSelectedRunner] = useState(null); const [botName, setBotName] = useState(''); - // eslint-disable-next-line @typescript-eslint/no-unused-vars + const [botDescription, _setBotDescription] = useState(''); const [adapterConfig, setAdapterConfig] = useState>( {}, From 4e45886647a47cd503b5e9ae2216834986b6b735 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Tue, 16 Jun 2026 06:04:59 -0400 Subject: [PATCH 10/16] style(web): show Models above API Integration in main sidebar footer --- .../components/home-sidebar/HomeSidebar.tsx | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx index d1f3acd77..d57e34d33 100644 --- a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx +++ b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx @@ -1883,19 +1883,6 @@ export default function HomeSidebar({ {/* Footer */} - {/* API Integration entry */} - - - openSettings('apiIntegration')} - tooltip={t('common.apiIntegration')} - > - - {t('common.apiIntegration')} - - - - {/* Models entry */} @@ -1909,6 +1896,19 @@ export default function HomeSidebar({ + {/* API Integration entry */} + + + openSettings('apiIntegration')} + tooltip={t('common.apiIntegration')} + > + + {t('common.apiIntegration')} + + + + {/* User menu using sidebar-07 nav-user DropdownMenu pattern */} From b3c6de20723826aa46d2166142d80fded4da8e69 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Tue, 16 Jun 2026 13:34:17 +0000 Subject: [PATCH 11/16] [codex] cover frontend CRUD smoke flows (#2253) * test: cover frontend CRUD smoke flows * test: add bot CRUD smoke coverage * test: add bot/pipeline advanced flows and cross-resource tests - Bot enable/disable toggle with state persistence - Bot detail tab switching (Configuration, Logs, Sessions) - Bot form dirty state and save button behavior - Bot name validation error display - Pipeline tab switching (Configuration, Dashboard) - Pipeline form dirty state - Pipeline name validation error display - Cross-resource flow: create pipeline then bind to bot - Empty states for bots, pipelines, knowledge bases, MCP servers --- web/tests/e2e/crud-smoke.spec.ts | 455 ++++++++++++++++++++++++++ web/tests/e2e/fixtures/langbot-api.ts | 426 +++++++++++++++++++++++- 2 files changed, 877 insertions(+), 4 deletions(-) create mode 100644 web/tests/e2e/crud-smoke.spec.ts diff --git a/web/tests/e2e/crud-smoke.spec.ts b/web/tests/e2e/crud-smoke.spec.ts new file mode 100644 index 000000000..bba7467b3 --- /dev/null +++ b/web/tests/e2e/crud-smoke.spec.ts @@ -0,0 +1,455 @@ +import { expect, Page, test } from '@playwright/test'; + +import { installLangBotApiMocks } from './fixtures/langbot-api'; + +async function save(page: Page) { + const button = page.getByRole('button', { name: /^Save$/ }); + await expect(button).toBeEnabled(); + await button.click(); +} + +async function submit(page: Page) { + await page.getByRole('button', { name: /^Submit$/ }).click(); +} + +async function confirmDelete(page: Page) { + await page + .getByRole('dialog') + .getByRole('button', { name: /^Confirm Delete$/ }) + .click(); +} + +test.describe('frontend CRUD smoke flows', () => { + test('creates, edits, and deletes a bot', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/bots?id=new'); + + await expect(page.locator('input[name="name"]')).toBeVisible(); + await page.locator('input[name="name"]').fill('Support Bot'); + await page + .locator('input[name="description"]') + .fill('Answers customer support questions.'); + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + + await expect(page).toHaveURL(/\/home\/bots\?id=bot-1$/); + await page.reload(); + await expect(page.locator('input[name="name"]')).toHaveValue('Support Bot'); + + await page + .locator('input[name="description"]') + .fill('Answers customer support questions with context.'); + await save(page); + await expect(page.locator('input[name="description"]')).toHaveValue( + 'Answers customer support questions with context.', + ); + + await page.getByRole('button', { name: /^Delete$/ }).click(); + await confirmDelete(page); + + await expect(page).toHaveURL(/\/home\/bots$/); + await expect(page.getByText('Select a bot from the sidebar')).toBeVisible(); + }); + + test('creates, edits, and deletes a pipeline', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/pipelines?id=new'); + + await expect(page.locator('input[name="basic.name"]')).toBeVisible(); + await page.locator('input[name="basic.name"]').fill('Escalation Pipeline'); + await page + .locator('input[name="basic.description"]') + .fill('Routes urgent customer issues.'); + await submit(page); + + await expect(page).toHaveURL(/\/home\/pipelines\?id=pipeline-1$/); + await page.reload(); + await expect(page.locator('input[name="basic.name"]')).toHaveValue( + 'Escalation Pipeline', + ); + + await page + .locator('input[name="basic.description"]') + .fill('Routes urgent customer issues to operators.'); + await save(page); + await expect(page.locator('input[name="basic.description"]')).toHaveValue( + 'Routes urgent customer issues to operators.', + ); + + await page.getByRole('button', { name: /^Delete$/ }).click(); + await confirmDelete(page); + + await expect(page).toHaveURL(/\/home\/pipelines$/); + await expect( + page.getByText('Select a pipeline from the sidebar'), + ).toBeVisible(); + }); + + test('creates, edits, and deletes a knowledge base', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/knowledge?id=new'); + + await expect(page.locator('input[name="name"]')).toBeVisible(); + await page.locator('input[name="name"]').fill('Support Knowledge'); + await page + .locator('input[name="description"]') + .fill('Source material for support answers.'); + await submit(page); + + await expect(page).toHaveURL(/\/home\/knowledge\?id=knowledge-1$/); + await page.reload(); + await expect(page.locator('input[name="name"]')).toHaveValue( + 'Support Knowledge', + ); + await page.waitForTimeout(600); + + await page + .locator('input[name="description"]') + .fill('Updated source material for support answers.'); + await save(page); + await expect(page.locator('input[name="description"]')).toHaveValue( + 'Updated source material for support answers.', + ); + + await page.getByRole('button', { name: /^Delete$/ }).click(); + await confirmDelete(page); + + await expect(page).toHaveURL(/\/home\/knowledge$/); + await expect( + page.getByText('Select a knowledge base from the sidebar'), + ).toBeVisible(); + }); + + test('creates, edits, and deletes an MCP server', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/mcp?id=new'); + + await expect(page.locator('input[name="name"]')).toBeVisible(); + await page.locator('input[name="name"]').fill('playwright-mcp'); + await page + .locator('input[name="url"]') + .fill('https://mcp.example.test/sse'); + await submit(page); + + await expect(page).toHaveURL(/\/home\/mcp\?id=playwright-mcp$/); + await page.reload(); + await expect(page.locator('input[name="name"]')).toHaveValue( + 'playwright-mcp', + ); + + await page + .locator('input[name="url"]') + .fill('https://mcp.example.test/updated-sse'); + await save(page); + await expect(page.locator('input[name="url"]')).toHaveValue( + 'https://mcp.example.test/updated-sse', + ); + + await page.getByRole('button', { name: /^Delete$/ }).click(); + await confirmDelete(page); + + await expect(page).toHaveURL(/\/home\/mcp$/); + await expect( + page.getByText('Select an MCP server from the sidebar'), + ).toBeVisible(); + }); + + test('updates and deletes a manually-created skill', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/skills?action=create'); + + await page.locator('#display_name').fill('Release Notes'); + await page.locator('#name').fill('release_notes'); + await page.locator('#description').fill('Drafts release notes.'); + await page + .locator('#instructions') + .fill('Summarize merged changes for the next release.'); + await save(page); + + await expect(page).toHaveURL(/\/home\/skills\?id=release_notes$/); + await page.reload(); + await expect(page.locator('#description')).toHaveValue( + 'Drafts release notes.', + ); + + await page + .locator('#description') + .fill('Drafts concise release notes for maintainers.'); + await expect(page.locator('#description')).toHaveValue( + 'Drafts concise release notes for maintainers.', + ); + await save(page); + await page.reload(); + await expect(page.locator('#description')).toHaveValue( + 'Drafts concise release notes for maintainers.', + ); + await expect(page.locator('#instructions')).toHaveValue( + 'Summarize merged changes for the next release.', + ); + + await page.getByRole('button', { name: /^Delete$/ }).click(); + await confirmDelete(page); + + await expect(page).toHaveURL(/\/home\/add-extension$/); + }); +}); + +test.describe('bot advanced flows', () => { + test('toggles bot enable/disable state', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a bot first + await page.goto('/home/bots?id=new'); + await page.locator('input[name="name"]').fill('Toggle Test Bot'); + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + + await expect(page).toHaveURL(/\/home\/bots\?id=bot-1$/); + + // Wait for the enable switch to load (it's fetched via getBot) + await expect(page.locator('#bot-enable-switch')).toBeVisible({ + timeout: 5000, + }); + + // Verify initial state is enabled + await expect(page.locator('#bot-enable-switch')).toBeChecked(); + + // Toggle to disabled + await page.locator('#bot-enable-switch').click(); + await expect(page.locator('#bot-enable-switch')).not.toBeChecked(); + + // Reload and verify state persisted + await page.reload(); + await expect(page.locator('#bot-enable-switch')).not.toBeChecked(); + }); + + test('switches between bot detail tabs', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a bot + await page.goto('/home/bots?id=new'); + await page.locator('input[name="name"]').fill('Tab Test Bot'); + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + + // Verify we're on the Configuration tab + await expect( + page.getByRole('tab', { name: /Configuration/ }), + ).toHaveAttribute('data-state', 'active'); + await expect(page.locator('input[name="name"]')).toBeVisible(); + + // Switch to Logs tab + await page.getByRole('tab', { name: /Logs/ }).click(); + await expect(page.getByRole('tab', { name: /Logs/ })).toHaveAttribute( + 'data-state', + 'active', + ); + + // Switch to Sessions tab + await page.getByRole('tab', { name: /Sessions/ }).click(); + await expect(page.getByRole('tab', { name: /Sessions/ })).toHaveAttribute( + 'data-state', + 'active', + ); + + // Switch back to Configuration + await page.getByRole('tab', { name: /Configuration/ }).click(); + await expect(page.locator('input[name="name"]')).toBeVisible(); + }); + + test('save button is disabled when form is clean', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a bot + await page.goto('/home/bots?id=new'); + await page.locator('input[name="name"]').fill('Clean Form Bot'); + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + + // After creation, save button should be disabled (form is clean) + const saveButton = page.getByRole('button', { name: /^Save$/ }); + await expect(saveButton).toBeDisabled(); + + // Edit the form + await page.locator('input[name="description"]').fill('New description'); + await expect(saveButton).toBeEnabled(); + + // Save + await saveButton.click(); + await expect(saveButton).toBeDisabled(); + }); + + test('shows validation error when bot name is empty', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/bots?id=new'); + + // Select adapter but leave name empty + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + + // Should show validation error for name (zod validation) + await expect(page.getByText(/cannot be empty/i)).toBeVisible(); + await expect(page).toHaveURL(/\/home\/bots\?id=new$/); + }); +}); + +test.describe('pipeline advanced flows', () => { + test('switches to monitoring tab from pipeline detail', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a pipeline + await page.goto('/home/pipelines?id=new'); + await page.locator('input[name="basic.name"]').fill('Tab Test Pipeline'); + await submit(page); + + // Verify we're on the Configuration tab + await expect( + page.getByRole('tab', { name: /Configuration/ }), + ).toHaveAttribute('data-state', 'active'); + + // Switch to Monitoring tab (labeled "Dashboard" in the pipeline context) + // Skip Debug tab as it requires WebSocket connection + await page.getByRole('tab', { name: /Dashboard/ }).click(); + await expect(page.getByRole('tab', { name: /Dashboard/ })).toHaveAttribute( + 'data-state', + 'active', + ); + + // Switch back to Configuration + await page.getByRole('tab', { name: /Configuration/ }).click(); + await expect(page.locator('input[name="basic.name"]')).toBeVisible(); + }); + + test('save button reflects form dirty state', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a pipeline + await page.goto('/home/pipelines?id=new'); + await page.locator('input[name="basic.name"]').fill('Dirty Form Pipeline'); + await submit(page); + + // Wait for the page to fully load and form to reset + await page.waitForTimeout(500); + + // Edit the form - use the name field which definitely triggers dirty state + await page + .locator('input[name="basic.name"]') + .fill('Dirty Form Pipeline Updated'); + const saveButton = page.getByRole('button', { name: /^Save$/ }); + await expect(saveButton).toBeEnabled(); + + // Save + await saveButton.click(); + // Wait for save to complete + await page.waitForTimeout(500); + }); + + test('shows validation error when pipeline name is empty', async ({ + page, + }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/pipelines?id=new'); + + // Submit without filling name + await submit(page); + + // Should show validation error for name (zod validation) + await expect(page.getByText(/cannot be empty/i)).toBeVisible(); + await expect(page).toHaveURL(/\/home\/pipelines\?id=new$/); + }); +}); + +test.describe('cross-resource flows', () => { + test('creates a pipeline then binds it to a bot', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + // Create a pipeline first + await page.goto('/home/pipelines?id=new'); + await page.locator('input[name="basic.name"]').fill('Production Pipeline'); + await submit(page); + await expect(page).toHaveURL(/\/home\/pipelines\?id=pipeline-1$/); + + // Create a bot + await page.goto('/home/bots?id=new'); + await page.locator('input[name="name"]').fill('Bound Bot'); + await page.getByRole('combobox').click(); + await page.getByRole('option', { name: 'Playwright Adapter' }).click(); + await submit(page); + await expect(page).toHaveURL(/\/home\/bots\?id=bot-1$/); + + // Wait for form to fully load + await expect(page.locator('input[name="name"]')).toHaveValue('Bound Bot'); + + // Find the pipeline select by its label "Bind Pipeline" + const pipelineCard = page.getByText('Bind Pipeline').locator('..'); + await expect(pipelineCard).toBeVisible({ timeout: 5000 }); + + // Click on the select trigger within the pipeline binding card + // The select trigger shows "Select Pipeline" placeholder initially + const pipelineSelectTrigger = page.getByText('Select Pipeline').first(); + await pipelineSelectTrigger.click(); + + // Select the pipeline option + await page.getByRole('option', { name: 'Production Pipeline' }).click(); + + // Save the bot + await save(page); + + // Reload and verify binding persisted + await page.reload(); + // The pipeline name should appear in the select trigger (not in sidebar or options) + await expect( + page + .locator('[data-slot="select-trigger"]') + .filter({ hasText: 'Production Pipeline' }), + ).toBeVisible(); + }); +}); + +test.describe('empty states', () => { + test('shows empty state when no bots exist', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/bots'); + await expect(page.getByText('Select a bot from the sidebar')).toBeVisible(); + }); + + test('shows empty state when no pipelines exist', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/pipelines'); + await expect( + page.getByText('Select a pipeline from the sidebar'), + ).toBeVisible(); + }); + + test('shows empty state when no knowledge bases exist', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/knowledge'); + await expect( + page.getByText('Select a knowledge base from the sidebar'), + ).toBeVisible(); + }); + + test('shows empty state when no MCP servers exist', async ({ page }) => { + await installLangBotApiMocks(page, { authenticated: true }); + + await page.goto('/home/mcp'); + await expect( + page.getByText('Select an MCP server from the sidebar'), + ).toBeVisible(); + }); +}); diff --git a/web/tests/e2e/fixtures/langbot-api.ts b/web/tests/e2e/fixtures/langbot-api.ts index 08f23a5bb..e0be0ebc4 100644 --- a/web/tests/e2e/fixtures/langbot-api.ts +++ b/web/tests/e2e/fixtures/langbot-api.ts @@ -11,7 +11,68 @@ interface SkillMock { updated_at: string; } +interface PipelineMock { + uuid: string; + name: string; + description: string; + config: JsonRecord; + emoji: string; + is_default: boolean; + updated_at: string; +} + +interface KnowledgeBaseMock { + uuid: string; + name: string; + description: string; + emoji: string; + knowledge_engine_plugin_id: string; + creation_settings: JsonRecord; + retrieval_settings: JsonRecord; + knowledge_engine: { + plugin_id: string; + name: { + en_US: string; + zh_Hans: string; + }; + capabilities: string[]; + }; + updated_at: string; +} + +interface MCPServerMock { + name: string; + mode: 'sse' | 'stdio' | 'http'; + enable: boolean; + extra_args: JsonRecord; + runtime_info: { + status: 'connected'; + tool_count: number; + tools: unknown[]; + }; + readme: string; + updated_at: string; +} + +interface BotMock { + uuid: string; + name: string; + description: string; + enable: boolean; + adapter: string; + adapter_config: JsonRecord; + use_pipeline_uuid?: string; + pipeline_routing_rules: unknown[]; + adapter_runtime_values: JsonRecord; + updated_at: string; +} + interface LangBotApiMockState { + bots: BotMock[]; + counters: Record; + knowledgeBases: KnowledgeBaseMock[]; + mcpServers: MCPServerMock[]; + pipelines: PipelineMock[]; skills: SkillMock[]; } @@ -36,6 +97,19 @@ function routePath(route: Route) { return new URL(route.request().url()).pathname; } +function parseJsonBody(route: Route): JsonRecord { + return JSON.parse(route.request().postData() || '{}') as JsonRecord; +} + +function now() { + return new Date().toISOString(); +} + +function nextId(state: LangBotApiMockState, prefix: string) { + state.counters[prefix] = (state.counters[prefix] || 0) + 1; + return `${prefix}-${state.counters[prefix]}`; +} + function emptyMonitoringData() { return { overview: { @@ -93,6 +167,131 @@ function makeSkill(data: JsonRecord): SkillMock { }; } +function makePipeline( + state: LangBotApiMockState, + data: JsonRecord, + uuid = nextId(state, 'pipeline'), +): PipelineMock { + return { + uuid, + name: String(data.name || ''), + description: String(data.description || ''), + config: (data.config as JsonRecord | undefined) || { + ai: {}, + trigger: {}, + safety: {}, + output: {}, + }, + emoji: String(data.emoji || '⚙️'), + is_default: false, + updated_at: now(), + }; +} + +function knowledgeEngine() { + return { + plugin_id: 'builtin/minimal-knowledge', + name: { + en_US: 'Minimal Knowledge Engine', + zh_Hans: '最小知识库引擎', + }, + description: { + en_US: 'Minimal mocked engine for frontend smoke tests.', + zh_Hans: '用于前端冒烟测试的最小模拟引擎。', + }, + capabilities: ['text_retrieval'], + creation_schema: [], + retrieval_schema: [], + }; +} + +function makeKnowledgeBase( + state: LangBotApiMockState, + data: JsonRecord, + uuid = nextId(state, 'knowledge'), +): KnowledgeBaseMock { + const engine = knowledgeEngine(); + return { + uuid, + name: String(data.name || ''), + description: String(data.description || ''), + emoji: String(data.emoji || '📚'), + knowledge_engine_plugin_id: String( + data.knowledge_engine_plugin_id || engine.plugin_id, + ), + creation_settings: (data.creation_settings as JsonRecord | undefined) || {}, + retrieval_settings: + (data.retrieval_settings as JsonRecord | undefined) || {}, + knowledge_engine: { + plugin_id: engine.plugin_id, + name: engine.name, + capabilities: engine.capabilities, + }, + updated_at: now(), + }; +} + +function makeMCPServer(data: JsonRecord): MCPServerMock { + return { + name: String(data.name || ''), + mode: (data.mode as MCPServerMock['mode']) || 'sse', + enable: data.enable !== false, + extra_args: (data.extra_args as JsonRecord | undefined) || {}, + runtime_info: { + status: 'connected', + tool_count: 0, + tools: [], + }, + readme: '', + updated_at: now(), + }; +} + +function makeBot( + state: LangBotApiMockState, + data: JsonRecord, + uuid = nextId(state, 'bot'), +): BotMock { + return { + uuid, + name: String(data.name || ''), + description: String(data.description || ''), + enable: data.enable !== false, + adapter: String(data.adapter || 'playwright-adapter'), + adapter_config: (data.adapter_config as JsonRecord | undefined) || {}, + use_pipeline_uuid: data.use_pipeline_uuid + ? String(data.use_pipeline_uuid) + : undefined, + pipeline_routing_rules: + (data.pipeline_routing_rules as unknown[] | undefined) || [], + adapter_runtime_values: { + webhook_full_url: `https://playwright.test/bots/${uuid}/webhook`, + extra_webhook_full_url: '', + }, + updated_at: now(), + }; +} + +function mockAdapters() { + return [ + { + name: 'playwright-adapter', + label: { + en_US: 'Playwright Adapter', + zh_Hans: 'Playwright 适配器', + }, + description: { + en_US: 'Minimal adapter for frontend E2E tests.', + zh_Hans: '用于前端 E2E 测试的最小适配器。', + }, + spec: { + categories: ['testing'], + config: [], + }, + }, + ]; +} + async function handleBackendApi(route: Route, state: LangBotApiMockState) { const request = route.request(); const url = new URL(request.url()); @@ -147,16 +346,160 @@ async function handleBackendApi(route: Route, state: LangBotApiMockState) { return fulfillJson(route, { credits: null }); } + if (path === '/api/v1/platform/adapters') { + return fulfillJson(route, { adapters: mockAdapters() }); + } + if (path === '/api/v1/platform/bots') { - return fulfillJson(route, { bots: [] }); + if (method === 'POST') { + const bot = makeBot(state, parseJsonBody(route)); + state.bots = [ + ...state.bots.filter((item) => item.uuid !== bot.uuid), + bot, + ]; + return fulfillJson(route, { uuid: bot.uuid }); + } + + return fulfillJson(route, { bots: state.bots }); + } + + const botLogsMatch = path.match(/^\/api\/v1\/platform\/bots\/([^/]+)\/logs$/); + if (botLogsMatch) { + return fulfillJson(route, { logs: [], total: 0 }); + } + + const botMatch = path.match(/^\/api\/v1\/platform\/bots\/([^/]+)$/); + if (botMatch) { + const botId = decodeURIComponent(botMatch[1]); + + if (method === 'PUT') { + const bot = makeBot(state, parseJsonBody(route), botId); + state.bots = [...state.bots.filter((item) => item.uuid !== botId), bot]; + return fulfillJson(route, {}); + } + + if (method === 'DELETE') { + state.bots = state.bots.filter((item) => item.uuid !== botId); + return fulfillJson(route, {}); + } + + const bot = state.bots.find((item) => item.uuid === botId); + return fulfillJson(route, { + bot: bot || makeBot(state, { name: botId }, botId), + }); + } + + if (path === '/api/v1/pipelines/_/metadata') { + return fulfillJson(route, { configs: [] }); } if (path === '/api/v1/pipelines') { - return fulfillJson(route, { pipelines: [] }); + if (method === 'POST') { + const pipeline = makePipeline(state, parseJsonBody(route)); + state.pipelines = [ + ...state.pipelines.filter((item) => item.uuid !== pipeline.uuid), + pipeline, + ]; + return fulfillJson(route, { uuid: pipeline.uuid }); + } + + return fulfillJson(route, { pipelines: state.pipelines }); + } + + const pipelineMatch = path.match(/^\/api\/v1\/pipelines\/([^/]+)$/); + if (pipelineMatch) { + const pipelineId = decodeURIComponent(pipelineMatch[1]); + + if (method === 'PUT') { + const pipeline = makePipeline(state, parseJsonBody(route), pipelineId); + state.pipelines = [ + ...state.pipelines.filter((item) => item.uuid !== pipelineId), + pipeline, + ]; + return fulfillJson(route, {}); + } + + if (method === 'DELETE') { + state.pipelines = state.pipelines.filter( + (item) => item.uuid !== pipelineId, + ); + return fulfillJson(route, {}); + } + + const pipeline = state.pipelines.find((item) => item.uuid === pipelineId); + return fulfillJson(route, { + pipeline: + pipeline || makePipeline(state, { name: pipelineId }, pipelineId), + }); + } + + const pipelineExtensionsMatch = path.match( + /^\/api\/v1\/pipelines\/([^/]+)\/extensions$/, + ); + if (pipelineExtensionsMatch) { + return fulfillJson(route, { + enable_all_plugins: true, + enable_all_mcp_servers: true, + enable_all_skills: true, + bound_plugins: [], + available_plugins: [], + bound_mcp_servers: [], + available_mcp_servers: state.mcpServers, + bound_skills: [], + available_skills: state.skills, + }); } if (path === '/api/v1/knowledge/bases') { - return fulfillJson(route, { bases: [] }); + if (method === 'POST') { + const base = makeKnowledgeBase(state, parseJsonBody(route)); + state.knowledgeBases = [ + ...state.knowledgeBases.filter((item) => item.uuid !== base.uuid), + base, + ]; + return fulfillJson(route, { uuid: base.uuid }); + } + + return fulfillJson(route, { bases: state.knowledgeBases }); + } + + const knowledgeBaseFilesMatch = path.match( + /^\/api\/v1\/knowledge\/bases\/([^/]+)\/files$/, + ); + if (knowledgeBaseFilesMatch) { + return fulfillJson(route, { files: [] }); + } + + const knowledgeBaseMatch = path.match( + /^\/api\/v1\/knowledge\/bases\/([^/]+)$/, + ); + if (knowledgeBaseMatch) { + const baseId = decodeURIComponent(knowledgeBaseMatch[1]); + + if (method === 'PUT') { + const base = makeKnowledgeBase(state, parseJsonBody(route), baseId); + state.knowledgeBases = [ + ...state.knowledgeBases.filter((item) => item.uuid !== baseId), + base, + ]; + return fulfillJson(route, { uuid: base.uuid }); + } + + if (method === 'DELETE') { + state.knowledgeBases = state.knowledgeBases.filter( + (item) => item.uuid !== baseId, + ); + return fulfillJson(route, {}); + } + + const base = state.knowledgeBases.find((item) => item.uuid === baseId); + return fulfillJson(route, { + base: base || makeKnowledgeBase(state, { name: baseId }, baseId), + }); + } + + if (path === '/api/v1/knowledge/engines') { + return fulfillJson(route, { engines: [knowledgeEngine()] }); } if (path === '/api/v1/knowledge/migration/status') { @@ -176,7 +519,60 @@ async function handleBackendApi(route: Route, state: LangBotApiMockState) { } if (path === '/api/v1/mcp/servers') { - return fulfillJson(route, { servers: [] }); + if (method === 'POST') { + const server = makeMCPServer(parseJsonBody(route)); + state.mcpServers = [ + ...state.mcpServers.filter((item) => item.name !== server.name), + server, + ]; + return fulfillJson(route, { task_id: nextId(state, 'task') }); + } + + return fulfillJson(route, { servers: state.mcpServers }); + } + + const mcpTestMatch = path.match(/^\/api\/v1\/mcp\/servers\/([^/]+)\/test$/); + if (mcpTestMatch) { + return fulfillJson(route, { + runtime_info: { + status: 'connected', + tool_count: 0, + tools: [], + }, + }); + } + + const mcpServerMatch = path.match(/^\/api\/v1\/mcp\/servers\/([^/]+)$/); + if (mcpServerMatch) { + const serverName = decodeURIComponent(mcpServerMatch[1]); + + if (method === 'PUT') { + const existing = state.mcpServers.find( + (item) => item.name === serverName, + ); + const server = makeMCPServer({ + ...(existing || {}), + ...parseJsonBody(route), + name: serverName, + }); + state.mcpServers = [ + ...state.mcpServers.filter((item) => item.name !== serverName), + server, + ]; + return fulfillJson(route, { task_id: nextId(state, 'task') }); + } + + if (method === 'DELETE') { + state.mcpServers = state.mcpServers.filter( + (item) => item.name !== serverName, + ); + return fulfillJson(route, { task_id: nextId(state, 'task') }); + } + + const server = state.mcpServers.find((item) => item.name === serverName); + return fulfillJson(route, { + server: server || makeMCPServer({ name: serverName }), + }); } if (path === '/api/v1/skills') { @@ -229,6 +625,23 @@ async function handleBackendApi(route: Route, state: LangBotApiMockState) { const skillMatch = path.match(/^\/api\/v1\/skills\/([^/]+)$/); if (skillMatch) { const skillName = decodeURIComponent(skillMatch[1]); + if (method === 'PUT') { + const skill = makeSkill({ + ...parseJsonBody(route), + name: skillName, + }); + state.skills = [ + ...state.skills.filter((item) => item.name !== skillName), + skill, + ]; + return fulfillJson(route, { skill }); + } + + if (method === 'DELETE') { + state.skills = state.skills.filter((item) => item.name !== skillName); + return fulfillJson(route, {}); + } + const skill = state.skills.find((item) => item.name === skillName) || { name: skillName, display_name: '', @@ -389,6 +802,11 @@ export async function installLangBotApiMocks( ) { const { authenticated = false, storage = {} } = options; const state: LangBotApiMockState = { + bots: [], + counters: {}, + knowledgeBases: [], + mcpServers: [], + pipelines: [], skills: [], }; From a1e6eccdeb1ac6af9441c257a7304ee1432b94ab Mon Sep 17 00:00:00 2001 From: Junyan Chin Date: Thu, 18 Jun 2026 21:40:31 +0800 Subject: [PATCH 12/16] feat(box): bidirectional attachment transfer for sandbox (#2257) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(box): bidirectional attachment transfer for sandbox Materialize inbound attachments into the sandbox workspace so agents can process user-sent files, and collect agent-produced files from the outbox to attach them back to the reply. - box(service): add materialize_inbound_attachments / collect_outbound attachments. Prefer direct host-filesystem read/write on the bind-mounted workspace (no size limit), falling back to chunked exec only for non-shared backends (e2b/remote). Clear per-query inbox/outbox dirs at turn start to avoid query_id-reuse collisions. - provider(localagent): inject inbound attachment descriptors into the sandbox and append a system note telling the agent the inbox/outbox paths. - pipeline(wrapper): collect outbox files on the final stream chunk and append them as attachment components to the response chain. - web(debug-dialog): render File components with a download link when base64/url is present; add base64/path fields to the File entity. - tests: cover inbound/outbound, large-file transfer without truncation, and stale-dir clearing (86 passing). * feat(box): support voice/file attachment round-trip end-to-end Extends the bidirectional attachment transfer to audio and arbitrary files through the real webchat UI, and fixes the model-payload errors that non-image attachments triggered. - platform(websocket_adapter): resolve Voice/File component storage keys to base64 (previously only Image), so audio/documents reach the sandbox inbox. - web(debug-dialog): accept audio/* and any file in the uploader (was image-only), classify by mimetype, upload Voice/File via the documents endpoint, and render non-image staged attachments as a chip. - provider(litellmchat): drop non-image file parts (file_base64 / file_url) when building the OpenAI/LiteLLM payload. These come from Voice/File attachments — including ones replayed from conversation history — and the agent reads their bytes from the sandbox, not the model. Without this the provider rejects the request: 'invalid content type=file_base64'. - provider(localagent): also strip those parts from the current user message alongside the sandbox-path note (model-facing clarity; the requester is the real safety net for history). - tests: cover the requester strip/keep behavior (file dropped, image kept and reshaped to image_url, mixed history, plain-string content). * test(box): cover inbound/outbound attachment helpers; fix ruff format - ruff format localagent.py (CI ruff format --check was failing) - add unit tests for ResponseWrapper outbound-attachment helpers (wrapper.py 78%->98%) - add unit tests for LocalAgentRunner._inject_inbound_attachments - add unit tests for WebSocketAdapter._process_image_components (0%->covered) Lifts PR patch coverage from 68.97% to ~88% (>75% target). --- src/langbot/pkg/box/service.py | 422 ++++++++++++++++++ src/langbot/pkg/pipeline/wrapper/wrapper.py | 57 ++- .../pkg/platform/sources/websocket_adapter.py | 42 +- .../modelmgr/requesters/litellmchat.py | 11 + .../pkg/provider/runners/localagent.py | 68 +++ tests/unit_tests/box/test_box_service.py | 252 +++++++++++ .../test_wrapper_outbound_attachments.py | 146 ++++++ .../test_websocket_adapter_attachments.py | 92 ++++ .../provider/test_litellm_convert_messages.py | 93 ++++ .../test_localagent_inbound_attachments.py | 146 ++++++ .../components/debug-dialog/DebugDialog.tsx | 122 +++-- web/src/app/infra/entities/message/index.ts | 2 + 12 files changed, 1405 insertions(+), 48 deletions(-) create mode 100644 tests/unit_tests/pipeline/test_wrapper_outbound_attachments.py create mode 100644 tests/unit_tests/platform/test_websocket_adapter_attachments.py create mode 100644 tests/unit_tests/provider/test_litellm_convert_messages.py create mode 100644 tests/unit_tests/provider/test_localagent_inbound_attachments.py diff --git a/src/langbot/pkg/box/service.py b/src/langbot/pkg/box/service.py index 0eaa0973b..dc6c317e3 100644 --- a/src/langbot/pkg/box/service.py +++ b/src/langbot/pkg/box/service.py @@ -335,6 +335,428 @@ class BoxService: return await self.execute_spec_payload(spec_payload, query) + # ── Attachment passthrough (inbound / outbound) ────────────────── + # + # IM/webchat attachments (images, voices, files) reach the LLM as + # multimodal content, but historically never landed on the sandbox + # filesystem, so the agent's exec/read/write tools could not operate on + # them. Conversely, files the agent produced inside the sandbox were + # never surfaced back to the user. These two helpers close both gaps: + # + # inbound : message_chain attachments -> /workspace/inbox// + # outbound : /workspace/outbox// -> reply MessageChain + # + # Transfer prefers DIRECT HOST FILESYSTEM access to the bind-mounted + # workspace (default_workspace on the host maps to /workspace inside the + # container), which has no size limit. This covers the local docker / + # nsjail / stdio backends. For backends where the workspace is NOT visible + # on the LangBot host (E2B, an external remote runtime.endpoint), it falls + # back to a base64-through-exec round-trip. The exec channel can only move + # small files reliably — the docker backend passes the command as a single + # argv (ARG_MAX) and exec stdout is truncated by output_limit_chars — so + # the host path is strongly preferred and used whenever available. + + INBOX_MOUNT_DIR = '/workspace/inbox' + OUTBOX_MOUNT_DIR = '/workspace/outbox' + INBOX_SUBDIR = 'inbox' + OUTBOX_SUBDIR = 'outbox' + # Hard cap on a single attachment. The HTTP upload endpoints already cap + # uploads at 10MiB; keep parity. + _ATTACHMENT_MAX_BYTES = 10 * _MIB + # Conservative cap for the exec FALLBACK path only (ARG_MAX / stdout + # truncation). The host-filesystem path has no such limit. + _EXEC_FALLBACK_MAX_BYTES = 256 * 1024 + + def _host_query_dir(self, subdir: str, query_id) -> str | None: + """Host path for ``/workspace//`` when LangBot can + access the bind-mounted workspace directly, else ``None``. + + ``default_workspace`` is the host directory bind-mounted to + ``/workspace`` for the local docker/nsjail backends and shared + outright in stdio mode, so a file written there by LangBot is visible + to the sandbox (and vice-versa). It is ``None`` / not a local dir for + E2B and remote runtimes, where we must fall back to the exec channel. + """ + root = self.default_workspace + if not root or not os.path.isdir(root): + return None + return os.path.join(root, subdir, str(query_id)) + + @staticmethod + def _sanitize_attachment_name(name: str, fallback: str) -> str: + """Reduce an arbitrary attachment name to a safe basename. + + Strips directory separators and parent refs so a crafted file name + can never escape the inbox/outbox directory. + """ + base = os.path.basename(str(name or '').replace('\\', '/').strip()) + base = base.lstrip('.') or '' + # Drop anything that is not a conservative filename charset. + cleaned = ''.join(c for c in base if c.isalnum() or c in ('.', '_', '-', ' ')).strip() + cleaned = cleaned.replace(' ', '_') + return cleaned or fallback + + @staticmethod + async def _component_to_bytes(component) -> tuple[bytes, str] | None: + """Best-effort extraction of (bytes, mime) from a platform component. + + Handles base64, http(s) url and local path sources. Returns None when + no payload can be resolved. + """ + import base64 as _b64 + + b64 = getattr(component, 'base64', None) + if b64: + data = b64 + mime = 'application/octet-stream' + if isinstance(data, str) and data.startswith('data:'): + split_index = data.find(';base64,') + if split_index != -1: + mime = data[5:split_index] + data = data[split_index + 8 :] + try: + return _b64.b64decode(data), mime + except Exception: + return None + + url = getattr(component, 'url', None) + if url: + try: + import httpx + + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.get(url) + resp.raise_for_status() + return resp.content, resp.headers.get('Content-Type', 'application/octet-stream') + except Exception: + return None + + path = getattr(component, 'path', None) + if path: + try: + import aiofiles + + async with aiofiles.open(path, 'rb') as f: + return await f.read(), 'application/octet-stream' + except Exception: + return None + + return None + + async def _write_files_into_sandbox( + self, + query: pipeline_query.Query, + subdir: str, + target_mount_dir: str, + files: list[tuple[str, bytes]], + ) -> list[str]: + """Write *files* (name, bytes) into the per-query directory. + + Prefers a direct host-filesystem write to the bind-mounted workspace + (no size limit). Falls back to a base64-through-exec round-trip only + when the workspace is not visible on the LangBot host (E2B / remote). + Returns the list of in-sandbox paths actually written. + """ + if not files: + return [] + + host_dir = self._host_query_dir(subdir, query.query_id) + if host_dir is not None: + return await asyncio.to_thread(self._write_files_host, host_dir, target_mount_dir, files) + + return await self._write_files_via_exec(query, target_mount_dir, files) + + def _write_files_host( + self, + host_dir: str, + target_mount_dir: str, + files: list[tuple[str, bytes]], + ) -> list[str]: + """Write attachments straight onto the bind-mounted host directory. + + Recreates the per-query directory from scratch so a reused query_id + (the webchat session uses small sequential ids) never inherits stale + files from an earlier turn. + """ + import shutil + + shutil.rmtree(host_dir, ignore_errors=True) + os.makedirs(host_dir, exist_ok=True) + written: list[str] = [] + for name, data in files: + with open(os.path.join(host_dir, name), 'wb') as fh: + fh.write(data) + written.append(f'{target_mount_dir}/{name}') + return written + + async def _write_files_via_exec( + self, + query: pipeline_query.Query, + target_dir: str, + files: list[tuple[str, bytes]], + ) -> list[str]: + """Fallback: ship files into the sandbox over the exec channel. + + Only used for backends without host-filesystem access (E2B / remote). + Each file is base64-decoded inside the sandbox. Files larger than the + conservative exec cap are skipped (ARG_MAX / stdout limits). + """ + import base64 as _b64 + import json as _json + + manifest = [] + for name, data in files: + if len(data) > self._EXEC_FALLBACK_MAX_BYTES: + self.ap.logger.warning( + f'Attachment "{name}" ({len(data)} bytes) exceeds the exec-channel ' + f'fallback limit ({self._EXEC_FALLBACK_MAX_BYTES} bytes); skipping. ' + f'Configure a host-shared workspace to transfer large files.' + ) + continue + manifest.append({'name': name, 'b64': _b64.b64encode(data).decode('ascii')}) + if not manifest: + return [] + + manifest_b64 = _b64.b64encode(_json.dumps(manifest).encode('utf-8')).decode('ascii') + script = ( + 'import base64, json, os, shutil\n' + f'target = {target_dir!r}\n' + 'shutil.rmtree(target, ignore_errors=True)\n' + 'os.makedirs(target, exist_ok=True)\n' + f'manifest = json.loads(base64.b64decode({manifest_b64!r}))\n' + 'written = []\n' + 'for item in manifest:\n' + " p = os.path.join(target, item['name'])\n" + " with open(p, 'wb') as f:\n" + " f.write(base64.b64decode(item['b64']))\n" + ' written.append(p)\n' + 'print(json.dumps(written))\n' + ) + result = await self.execute_tool( + {'command': f"python3 - <<'LBPY'\n{script}\nLBPY", 'timeout_sec': 120}, + query, + ) + if not result.get('ok'): + self.ap.logger.warning( + f'Failed to write inbound attachments into sandbox via exec: ' + f'query_id={query.query_id} stderr={result.get("stderr", "")[:200]}' + ) + return [] + try: + return _json.loads(str(result.get('stdout') or '').strip().splitlines()[-1]) + except Exception: + return [] + + async def materialize_inbound_attachments(self, query: pipeline_query.Query) -> list[dict]: + """Persist message-chain attachments into the sandbox inbox. + + Returns a list of ``{path, name, type, size}`` describing what was + written, so the runner can tell the LLM the exact in-sandbox paths. + Returns ``[]`` when sandbox is unavailable or there are no attachments. + """ + if not self._available: + return [] + + import langbot_plugin.api.entities.builtin.platform.message as platform_message + + message_chain = getattr(query, 'message_chain', None) + if not message_chain: + return [] + + type_map = [ + (platform_message.Image, 'Image', 'image', 'png'), + (platform_message.Voice, 'Voice', 'voice', 'wav'), + (platform_message.File, 'File', 'file', 'bin'), + ] + + pending: list[tuple[str, bytes]] = [] + descriptors: list[dict] = [] + index = 0 + for component in message_chain: + matched = None + for cls, kind, prefix, default_ext in type_map: + if isinstance(component, cls): + matched = (kind, prefix, default_ext) + break + if matched is None: + continue + kind, prefix, default_ext = matched + + payload = await self._component_to_bytes(component) + if payload is None: + continue + data, _mime = payload + if not data or len(data) > self._ATTACHMENT_MAX_BYTES: + continue + + index += 1 + raw_name = getattr(component, 'name', None) or f'{prefix}_{index}.{default_ext}' + safe_name = self._sanitize_attachment_name(raw_name, f'{prefix}_{index}.{default_ext}') + pending.append((safe_name, data)) + descriptors.append( + { + 'name': safe_name, + 'type': kind, + 'size': len(data), + } + ) + + if not pending: + return [] + + target_dir = f'{self.INBOX_MOUNT_DIR}/{query.query_id}' + written = await self._write_files_into_sandbox(query, self.INBOX_SUBDIR, target_dir, pending) + written_basenames = {os.path.basename(p) for p in written} + + result: list[dict] = [] + for desc in descriptors: + if desc['name'] in written_basenames: + desc['path'] = f'{target_dir}/{desc["name"]}' + result.append(desc) + if result: + self.ap.logger.info( + f'Materialized {len(result)} inbound attachment(s) into sandbox: ' + f'query_id={query.query_id} dir={target_dir}' + ) + return result + + async def collect_outbound_attachments(self, query: pipeline_query.Query) -> list[dict]: + """Collect files the agent produced in the sandbox outbox. + + Reads ``/workspace/outbox//`` (recursively) — directly from + the bind-mounted host directory when available (no size limit), else + via the exec channel — returns a list of ``{type, name, base64}`` + ready to become platform message components, then clears the outbox so + a later turn in the same session does not re-send stale files. Returns + ``[]`` when nothing was produced. + """ + if not self._available: + return [] + + host_dir = self._host_query_dir(self.OUTBOX_SUBDIR, query.query_id) + if host_dir is not None: + entries = await asyncio.to_thread(self._read_outbox_host, host_dir) + else: + entries = await self._read_outbox_via_exec(query) + + attachments = self._classify_outbound_entries(entries) + + if attachments: + await self._clear_outbox(query, host_dir) + self.ap.logger.info( + f'Collected {len(attachments)} outbound attachment(s) from sandbox: query_id={query.query_id}' + ) + return attachments + + def _read_outbox_host(self, host_dir: str) -> list[dict]: + """Read outbox files straight off the bind-mounted host directory.""" + import base64 as _b64 + + entries: list[dict] = [] + if not os.path.isdir(host_dir): + return entries + for root, _dirs, names in os.walk(host_dir): + for name in sorted(names): + path = os.path.join(root, name) + try: + if os.path.getsize(path) > self._ATTACHMENT_MAX_BYTES: + continue + with open(path, 'rb') as fh: + data = fh.read() + except OSError: + continue + rel = os.path.relpath(path, host_dir) + entries.append({'name': rel, 'b64': _b64.b64encode(data).decode('ascii')}) + return entries + + async def _read_outbox_via_exec(self, query: pipeline_query.Query) -> list[dict]: + """Fallback: read the outbox over the exec channel (E2B / remote). + + Note: exec stdout is truncated by ``output_limit_chars``, so this path + only reliably transfers small files. The host path is preferred. + """ + import json as _json + + target_dir = f'{self.OUTBOX_MOUNT_DIR}/{query.query_id}' + max_bytes = self._EXEC_FALLBACK_MAX_BYTES + script = ( + 'import base64, json, os\n' + f'target = {target_dir!r}\n' + f'max_bytes = {max_bytes}\n' + 'out = []\n' + 'if os.path.isdir(target):\n' + ' for root, _dirs, names in os.walk(target):\n' + ' for n in sorted(names):\n' + ' p = os.path.join(root, n)\n' + ' try:\n' + ' if os.path.getsize(p) > max_bytes:\n' + ' continue\n' + " with open(p, 'rb') as f:\n" + ' data = f.read()\n' + ' except OSError:\n' + ' continue\n' + ' rel = os.path.relpath(p, target)\n' + " out.append({'name': rel, 'b64': base64.b64encode(data).decode('ascii')})\n" + 'print(json.dumps(out))\n' + ) + result = await self.execute_tool( + {'command': f"python3 - <<'LBPY'\n{script}\nLBPY", 'timeout_sec': 120}, + query, + ) + if not result.get('ok'): + return [] + try: + return _json.loads(str(result.get('stdout') or '').strip().splitlines()[-1]) + except Exception: + return [] + + async def _clear_outbox(self, query: pipeline_query.Query, host_dir: str | None) -> None: + """Empty the per-query outbox after collection (host or exec).""" + if host_dir is not None: + import shutil + + def _clear(): + shutil.rmtree(host_dir, ignore_errors=True) + os.makedirs(host_dir, exist_ok=True) + + await asyncio.to_thread(_clear) + return + target_dir = f'{self.OUTBOX_MOUNT_DIR}/{query.query_id}' + await self.execute_tool( + {'command': f'rm -rf {target_dir} && mkdir -p {target_dir}', 'timeout_sec': 30}, + query, + ) + + @staticmethod + def _classify_outbound_entries(entries: list[dict]) -> list[dict]: + """Classify outbox files into Image/Voice/File component descriptors.""" + image_exts = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'bmp'} + voice_exts = {'wav', 'mp3', 'silk', 'amr', 'ogg', 'm4a', 'aac'} + mime_by_ext = { + 'png': 'image/png', + 'jpg': 'image/jpeg', + 'jpeg': 'image/jpeg', + 'gif': 'image/gif', + 'webp': 'image/webp', + 'bmp': 'image/bmp', + } + attachments: list[dict] = [] + for entry in entries or []: + name = str(entry.get('name', '') or '') + b64 = entry.get('b64') + if not name or not b64: + continue + ext = name.rsplit('.', 1)[-1].lower() if '.' in name else '' + base_name = os.path.basename(name) + if ext in image_exts: + mime = mime_by_ext.get(ext, 'image/png') + attachments.append({'type': 'Image', 'name': base_name, 'base64': f'data:{mime};base64,{b64}'}) + elif ext in voice_exts: + attachments.append({'type': 'Voice', 'name': base_name, 'base64': f'data:audio/{ext};base64,{b64}'}) + else: + attachments.append({'type': 'File', 'name': base_name, 'base64': b64}) + return attachments + async def shutdown(self): await self.client.shutdown() diff --git a/src/langbot/pkg/pipeline/wrapper/wrapper.py b/src/langbot/pkg/pipeline/wrapper/wrapper.py index a1ebc97a2..a158c1840 100644 --- a/src/langbot/pkg/pipeline/wrapper/wrapper.py +++ b/src/langbot/pkg/pipeline/wrapper/wrapper.py @@ -7,6 +7,7 @@ from .. import stage import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.events as events @@ -23,6 +24,50 @@ class ResponseWrapper(stage.PipelineStage): async def initialize(self, pipeline_config: dict): pass + def _is_final_assistant_message(self, result) -> bool: + """Whether *result* is the agent's final, tool-call-free answer. + + Intermediate streaming chunks and tool-call rounds must NOT trigger + outbound attachment collection — only the terminal assistant message. + """ + if getattr(result, 'role', None) != 'assistant': + return False + if result.tool_calls: + return False + if isinstance(result, provider_message.MessageChunk): + return bool(result.is_final) + return True + + async def _append_outbound_attachments( + self, + query: pipeline_query.Query, + message_chain: platform_message.MessageChain, + ) -> None: + """Collect sandbox outbox files and append them to *message_chain*. + + Runs at most once per query (guarded by a query variable) and never + raises into the pipeline — attachment delivery is best-effort. + """ + if query.variables.get('_sandbox_outbound_collected'): + return + box_service = getattr(self.ap, 'box_service', None) + if box_service is None or not getattr(box_service, 'available', False): + return + query.variables['_sandbox_outbound_collected'] = True + try: + attachments = await box_service.collect_outbound_attachments(query) + except Exception as e: + self.ap.logger.warning(f'Outbound attachment collection failed: {e}') + return + for att in attachments: + att_type = att.get('type') + if att_type == 'Image': + message_chain.append(platform_message.Image(base64=att['base64'])) + elif att_type == 'Voice': + message_chain.append(platform_message.Voice(base64=att['base64'])) + else: + message_chain.append(platform_message.File(name=att.get('name', 'file'), base64=att['base64'])) + async def process( self, query: pipeline_query.Query, @@ -83,10 +128,16 @@ class ResponseWrapper(stage.PipelineStage): ) else: if event_ctx.event.reply_message_chain is not None: - query.resp_message_chain.append(event_ctx.event.reply_message_chain) - + reply_chain = event_ctx.event.reply_message_chain else: - query.resp_message_chain.append(result.get_content_platform_message_chain()) + reply_chain = result.get_content_platform_message_chain() + + # Attach files the agent produced in the sandbox + # outbox, but only on the terminal assistant message. + if self._is_final_assistant_message(result): + await self._append_outbound_attachments(query, reply_chain) + + query.resp_message_chain.append(reply_chain) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/src/langbot/pkg/platform/sources/websocket_adapter.py b/src/langbot/pkg/platform/sources/websocket_adapter.py index 9ffcf04ac..0574292f3 100644 --- a/src/langbot/pkg/platform/sources/websocket_adapter.py +++ b/src/langbot/pkg/platform/sources/websocket_adapter.py @@ -312,12 +312,18 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) async def _process_image_components(self, message_chain_obj: list): """ - 处理消息链中的图片和文件组件,将path转换为base64 + 处理消息链中的图片、语音和文件组件,将 path 转换为 base64 + + Image / Voice / File components uploaded from the web client carry a + storage key in ``path``. Resolve it to a base64 data URI so downstream + stages (multimodal LLM input and the Box sandbox inbox) have a usable + payload, then drop the now-consumed storage object. Args: message_chain_obj: 消息链对象列表 """ import base64 + import mimetypes storage_mgr = self.ap.storage_mgr @@ -325,31 +331,33 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) comp_type = component.get('type', '') comp_path = component.get('path', '') - if not comp_path: + if not comp_path or comp_type not in ('Image', 'Voice', 'File'): continue - if comp_type == 'Image': - try: - file_content = await storage_mgr.storage_provider.load(comp_path) - base64_str = base64.b64encode(file_content).decode('utf-8') + try: + file_content = await storage_mgr.storage_provider.load(comp_path) + base64_str = base64.b64encode(file_content).decode('utf-8') - file_key = comp_path - if file_key.lower().endswith(('.jpg', '.jpeg')): + lowered = comp_path.lower() + if comp_type == 'Image': + if lowered.endswith(('.jpg', '.jpeg')): mime_type = 'image/jpeg' - elif file_key.lower().endswith('.png'): - mime_type = 'image/png' - elif file_key.lower().endswith('.gif'): + elif lowered.endswith('.gif'): mime_type = 'image/gif' - elif file_key.lower().endswith('.webp'): + elif lowered.endswith('.webp'): mime_type = 'image/webp' else: mime_type = 'image/png' + elif comp_type == 'Voice': + mime_type = mimetypes.guess_type(comp_path)[0] or 'audio/wav' + else: # File + mime_type = mimetypes.guess_type(comp_path)[0] or 'application/octet-stream' - component['base64'] = f'data:{mime_type};base64,{base64_str}' - await storage_mgr.storage_provider.delete(comp_path) - component['path'] = '' - except Exception as e: - await self.logger.error(f'Failed to load image file {comp_path}: {e}') + component['base64'] = f'data:{mime_type};base64,{base64_str}' + await storage_mgr.storage_provider.delete(comp_path) + component['path'] = '' + except Exception as e: + await self.logger.error(f'Failed to load {comp_type} file {comp_path}: {e}') async def handle_websocket_message( self, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 8c750bd7d..d58dd2c5f 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -216,11 +216,22 @@ class LiteLLMRequester(requester.ProviderAPIRequester): content = msg_dict.get('content') if isinstance(content, list): + converted_parts = [] for part in content: if isinstance(part, dict) and part.get('type') == 'image_base64': part['image_url'] = {'url': part['image_base64']} part['type'] = 'image_url' del part['image_base64'] + # OpenAI-compatible chat models reject non-image file parts + # (audio/document base64 or url). These originate from Voice / + # File attachments — including ones replayed from conversation + # history — and the agent already accesses their bytes via the + # sandbox. Drop them from the model payload to avoid + # "Invalid user message ... invalid content type=file_base64". + if isinstance(part, dict) and part.get('type') in ('file_base64', 'file_url'): + continue + converted_parts.append(part) + msg_dict['content'] = converted_parts req_messages.append(msg_dict) diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 9a90ed47d..338de2e59 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -104,6 +104,68 @@ class _StreamAccumulator: class LocalAgentRunner(runner.RequestRunner): """Local agent request runner""" + async def _inject_inbound_attachments( + self, + query: pipeline_query.Query, + user_message: provider_message.Message, + ) -> None: + """Persist inbound attachments into the sandbox and tell the model. + + No-op when the box service is unavailable or there are no attachments. + On success, appends an extra text ContentElement to the user message + listing the in-sandbox paths and the outbox convention, and stashes the + descriptors in ``query.variables['_sandbox_inbound_attachments']``. + """ + box_service = getattr(self.ap, 'box_service', None) + if box_service is None or not getattr(box_service, 'available', False): + return + try: + attachments = await box_service.materialize_inbound_attachments(query) + except Exception as e: # never break the chat turn over attachment IO + self.ap.logger.warning(f'Inbound attachment materialization failed: {e}') + return + if not attachments: + return + + query.variables['_sandbox_inbound_attachments'] = attachments + + lines = [ + 'The user sent attachments. They have been saved into the sandbox and are ' + 'available to the exec/read/write tools at these paths:' + ] + for att in attachments: + lines.append(f'- {att["type"]}: {att["path"]} ({att["size"]} bytes)') + outbox_dir = f'{box_service.OUTBOX_MOUNT_DIR}/{query.query_id}' + lines.append( + 'If you produce any file (image, audio, document, etc.) that should be sent ' + f'back to the user, write it into {outbox_dir}/ (create the directory if ' + 'needed). Every file placed there will be delivered to the user automatically.' + ) + note = '\n'.join(lines) + + # Voice/File attachments are now available to the agent via the sandbox + # (exec/read/write tools). Their raw bytes must NOT be forwarded to the + # chat model as multimodal content: providers reject non-image file + # parts ("Invalid user message ... ensure all user messages are valid + # OpenAI chat completion messages"). Strip those content elements and + # rely on the sandbox-path note instead. Images are kept so vision + # models can still see them. + _model_unsafe_types = {'file_base64', 'file_url'} + if isinstance(user_message.content, list): + user_message.content = [ + ce for ce in user_message.content if getattr(ce, 'type', None) not in _model_unsafe_types + ] + + if isinstance(user_message.content, str): + user_message.content = [ + provider_message.ContentElement.from_text(user_message.content), + provider_message.ContentElement.from_text(note), + ] + elif isinstance(user_message.content, list): + user_message.content.append(provider_message.ContentElement.from_text(note)) + else: + user_message.content = [provider_message.ContentElement.from_text(note)] + def _build_request_messages( self, query: pipeline_query.Query, @@ -232,6 +294,12 @@ class LocalAgentRunner(runner.RequestRunner): user_message = copy.deepcopy(query.user_message) + # Materialize inbound attachments (images / voices / files) into the + # sandbox so the agent's exec/read/write tools can operate on the real + # bytes — not just the multimodal copy the model sees. The exact + # in-sandbox paths are announced to the model as a system note. + await self._inject_inbound_attachments(query, user_message) + user_message_text = '' if isinstance(user_message.content, str): diff --git a/tests/unit_tests/box/test_box_service.py b/tests/unit_tests/box/test_box_service.py index c59a1c5e9..0b7183ba7 100644 --- a/tests/unit_tests/box/test_box_service.py +++ b/tests/unit_tests/box/test_box_service.py @@ -1556,3 +1556,255 @@ class TestBuildSkillExtraMounts: service = BoxService(app, client=Mock(spec=BoxRuntimeClient)) assert service.build_skill_extra_mounts(make_query()) == [] + + +# ── Attachment passthrough (inbound / outbound) ───────────────────────────── + + +class TestAttachmentHelpers: + def test_sanitize_attachment_name_strips_traversal(self): + assert BoxService._sanitize_attachment_name('../../etc/passwd', 'fb') == 'passwd' + assert BoxService._sanitize_attachment_name('/a/b/c.png', 'fb') == 'c.png' + assert BoxService._sanitize_attachment_name('a b c.txt', 'fb') == 'a_b_c.txt' + assert BoxService._sanitize_attachment_name('', 'fallback.bin') == 'fallback.bin' + assert BoxService._sanitize_attachment_name('...', 'fb.bin') == 'fb.bin' + # weird unicode / shell chars dropped, but keeps a usable name + out = BoxService._sanitize_attachment_name('rm -rf $(x).png', 'fb') + assert '/' not in out and '$' not in out and out.endswith('.png') + + def test_classify_outbound_entries_by_extension(self): + entries = [ + {'name': 'chart.png', 'b64': 'AAA'}, + {'name': 'clip.mp3', 'b64': 'BBB'}, + {'name': 'report.pdf', 'b64': 'CCC'}, + {'name': 'sub/dir/photo.JPG', 'b64': 'DDD'}, + {'name': 'noext', 'b64': 'EEE'}, + {'name': 'skip', 'b64': ''}, # dropped (no payload) + ] + out = BoxService._classify_outbound_entries(entries) + by_name = {a['name']: a for a in out} + assert by_name['chart.png']['type'] == 'Image' + assert by_name['chart.png']['base64'].startswith('data:image/png;base64,') + assert by_name['clip.mp3']['type'] == 'Voice' + assert by_name['clip.mp3']['base64'].startswith('data:audio/mp3;base64,') + assert by_name['report.pdf']['type'] == 'File' + assert by_name['report.pdf']['base64'] == 'CCC' # raw b64, no data: prefix + # nested path collapses to basename, case-insensitive ext + assert by_name['photo.JPG']['type'] == 'Image' + assert by_name['noext']['type'] == 'File' + assert 'skip' not in by_name + + @pytest.mark.asyncio + async def test_component_to_bytes_from_data_uri(self): + import base64 + + raw = b'hello-bytes' + data_uri = 'data:text/plain;base64,' + base64.b64encode(raw).decode() + component = SimpleNamespace(base64=data_uri, url=None, path=None) + result = await BoxService._component_to_bytes(component) + assert result is not None + data, mime = result + assert data == raw + assert mime == 'text/plain' + + @pytest.mark.asyncio + async def test_component_to_bytes_returns_none_when_empty(self): + component = SimpleNamespace(base64=None, url=None, path=None) + assert await BoxService._component_to_bytes(component) is None + + +class TestInboundOutboundRoundTrip: + def _service(self) -> BoxService: + service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient)) + service._available = True + return service + + @pytest.mark.asyncio + async def test_materialize_inbound_writes_and_describes(self): + import base64 + + import langbot_plugin.api.entities.builtin.platform.message as platform_message + + service = self._service() + + img_bytes = b'\x89PNG\r\n\x1a\n fake png' + img_b64 = 'data:image/png;base64,' + base64.b64encode(img_bytes).decode() + + query = make_query() + query.message_chain = platform_message.MessageChain( + [ + platform_message.Plain(text='please resize this'), + platform_message.Image(base64=img_b64), + ] + ) + + # Mock the sandbox write path: echo back the written paths. + async def fake_execute_tool(parameters, q): + assert '/workspace/inbox/' in parameters['command'] + return { + 'ok': True, + 'stdout': '["/workspace/inbox/42/image_1.png"]', + 'stderr': '', + } + + service.execute_tool = AsyncMock(side_effect=fake_execute_tool) + + descriptors = await service.materialize_inbound_attachments(query) + assert len(descriptors) == 1 + d = descriptors[0] + assert d['type'] == 'Image' + assert d['path'] == '/workspace/inbox/42/image_1.png' + assert d['size'] == len(img_bytes) + + @pytest.mark.asyncio + async def test_materialize_inbound_noop_without_attachments(self): + import langbot_plugin.api.entities.builtin.platform.message as platform_message + + service = self._service() + query = make_query() + query.message_chain = platform_message.MessageChain([platform_message.Plain(text='just text')]) + service.execute_tool = AsyncMock() + assert await service.materialize_inbound_attachments(query) == [] + service.execute_tool.assert_not_called() + + @pytest.mark.asyncio + async def test_collect_outbound_reads_and_clears(self): + service = self._service() + query = make_query() + + calls = [] + + async def fake_execute_tool(parameters, q): + calls.append(parameters['command']) + if 'os.walk' in parameters['command']: + return { + 'ok': True, + 'stdout': '[{"name": "out.png", "b64": "QUJD"}]', + 'stderr': '', + } + # the rm -rf cleanup call + return {'ok': True, 'stdout': '', 'stderr': ''} + + service.execute_tool = AsyncMock(side_effect=fake_execute_tool) + + attachments = await service.collect_outbound_attachments(query) + assert len(attachments) == 1 + assert attachments[0]['type'] == 'Image' + assert attachments[0]['name'] == 'out.png' + # cleanup (rm -rf) must have been issued after a successful collection + assert any('rm -rf' in c for c in calls) + + @pytest.mark.asyncio + async def test_collect_outbound_empty_no_cleanup(self): + service = self._service() + query = make_query() + + calls = [] + + async def fake_execute_tool(parameters, q): + calls.append(parameters['command']) + return {'ok': True, 'stdout': '[]', 'stderr': ''} + + service.execute_tool = AsyncMock(side_effect=fake_execute_tool) + assert await service.collect_outbound_attachments(query) == [] + assert not any('rm -rf' in c for c in calls) + + @pytest.mark.asyncio + async def test_passthrough_noop_when_unavailable(self): + service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient)) + service._available = False + query = make_query() + assert await service.materialize_inbound_attachments(query) == [] + assert await service.collect_outbound_attachments(query) == [] + + +class TestAttachmentHostPath: + """Direct host-filesystem transfer path (bind-mounted workspace). + + When ``default_workspace`` is a real local dir, inbound/outbound bypass the + exec channel entirely (no ARG_MAX / stdout-truncation limits) and read/write + the bind-mounted host dir directly. + """ + + def _service_with_workspace(self, tmp_path): + ws = str(tmp_path / 'box' / 'default') + os.makedirs(ws, exist_ok=True) + app = make_app(Mock(), allowed_mount_roots=[str(tmp_path)], host_root=str(tmp_path / 'box')) + service = BoxService(app, client=Mock(spec=BoxRuntimeClient)) + service._available = True + # Force the default_workspace to our tmp dir so _host_query_dir resolves. + service.default_workspace = ws + return service, ws + + @pytest.mark.asyncio + async def test_inbound_writes_to_host_no_exec(self, tmp_path): + import base64 + + import langbot_plugin.api.entities.builtin.platform.message as platform_message + + service, ws = self._service_with_workspace(tmp_path) + # Big payload that would blow ARG_MAX on the exec path: + big = b'\x89PNG\r\n\x1a\n' + b'x' * (300 * 1024) + b64 = 'data:image/png;base64,' + base64.b64encode(big).decode() + query = make_query() + query.message_chain = platform_message.MessageChain([platform_message.Image(base64=b64)]) + # execute_tool must NOT be called on the host path. + service.execute_tool = AsyncMock(side_effect=AssertionError('exec must not be used on host path')) + + descriptors = await service.materialize_inbound_attachments(query) + assert len(descriptors) == 1 + d = descriptors[0] + assert d['type'] == 'Image' + assert d['size'] == len(big) + # File actually landed on the host workspace. + host_file = os.path.join(ws, 'inbox', str(query.query_id), d['name']) + assert os.path.isfile(host_file) + assert open(host_file, 'rb').read() == big + + @pytest.mark.asyncio + async def test_inbound_host_clears_stale_query_dir(self, tmp_path): + import base64 + + import langbot_plugin.api.entities.builtin.platform.message as platform_message + + service, ws = self._service_with_workspace(tmp_path) + # Seed a stale file under the same query_id (simulates webchat id reuse). + stale_dir = os.path.join(ws, 'inbox', '42') + os.makedirs(stale_dir, exist_ok=True) + open(os.path.join(stale_dir, 'image_1.png'), 'wb').write(b'STALE-OLD-IMAGE') + + new = b'\x89PNG\r\n\x1a\n NEW' + b64 = 'data:image/png;base64,' + base64.b64encode(new).decode() + query = make_query(query_id=42) + query.message_chain = platform_message.MessageChain([platform_message.Image(base64=b64)]) + service.execute_tool = AsyncMock() + descriptors = await service.materialize_inbound_attachments(query) + # The new write recreated the dir; the stale file is gone, new bytes present. + host_file = os.path.join(stale_dir, descriptors[0]['name']) + assert open(host_file, 'rb').read() == new + # No leftover content from the stale image. + assert b'STALE-OLD-IMAGE' not in open(host_file, 'rb').read() + + @pytest.mark.asyncio + async def test_outbound_reads_host_and_clears(self, tmp_path): + service, ws = self._service_with_workspace(tmp_path) + query = make_query() + outbox = os.path.join(ws, 'outbox', str(query.query_id)) + os.makedirs(outbox, exist_ok=True) + # A large file that would be truncated on the exec/stdout path: + big_png = b'\x89PNG\r\n\x1a\n' + b'y' * (400 * 1024) + open(os.path.join(outbox, 'result.png'), 'wb').write(big_png) + open(os.path.join(outbox, 'notes.txt'), 'wb').write(b'hello') + + service.execute_tool = AsyncMock(side_effect=AssertionError('exec must not be used on host path')) + attachments = await service.collect_outbound_attachments(query) + by_name = {a['name']: a for a in attachments} + assert by_name['result.png']['type'] == 'Image' + assert by_name['notes.txt']['type'] == 'File' + # Full image survived (no truncation). + import base64 + + raw = base64.b64decode(by_name['result.png']['base64'].split(',', 1)[-1]) + assert raw == big_png + # Outbox cleared after collection. + assert os.listdir(outbox) == [] diff --git a/tests/unit_tests/pipeline/test_wrapper_outbound_attachments.py b/tests/unit_tests/pipeline/test_wrapper_outbound_attachments.py new file mode 100644 index 000000000..8fc000bf5 --- /dev/null +++ b/tests/unit_tests/pipeline/test_wrapper_outbound_attachments.py @@ -0,0 +1,146 @@ +"""Unit tests for ResponseWrapper outbound-attachment helpers. + +Covers the sandbox -> user attachment path added for the Box attachment +round-trip: + +* ``_is_final_assistant_message`` — only the terminal, tool-call-free assistant + message (or a final MessageChunk) should trigger collection. +* ``_append_outbound_attachments`` — collects sandbox outbox files exactly once + per query and maps each descriptor to the right platform component, swallowing + collection errors. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.provider.message as provider_message + +from langbot.pkg.pipeline.wrapper.wrapper import ResponseWrapper + + +def _make_wrapper(box_service) -> ResponseWrapper: + app = SimpleNamespace(logger=Mock()) + wrapper = ResponseWrapper.__new__(ResponseWrapper) + wrapper.ap = app + return wrapper + + +def _make_query(): + return SimpleNamespace(variables={}) + + +def test_is_final_assistant_message_plain_assistant(): + wrapper = _make_wrapper(box_service=None) + msg = provider_message.Message(role='assistant', content='done') + assert wrapper._is_final_assistant_message(msg) is True + + +def test_is_final_assistant_message_rejects_non_assistant(): + wrapper = _make_wrapper(box_service=None) + msg = provider_message.Message(role='tool', content='{}') + assert wrapper._is_final_assistant_message(msg) is False + + +def test_is_final_assistant_message_rejects_tool_call_round(): + wrapper = _make_wrapper(box_service=None) + msg = provider_message.Message( + role='assistant', + content='calling', + tool_calls=[ + provider_message.ToolCall( + id='c1', + type='function', + function=provider_message.FunctionCall(name='exec', arguments='{}'), + ) + ], + ) + assert wrapper._is_final_assistant_message(msg) is False + + +def test_is_final_assistant_message_non_final_chunk(): + wrapper = _make_wrapper(box_service=None) + chunk = provider_message.MessageChunk(role='assistant', content='partial', is_final=False) + assert wrapper._is_final_assistant_message(chunk) is False + + final_chunk = provider_message.MessageChunk(role='assistant', content='partial', is_final=True) + assert wrapper._is_final_assistant_message(final_chunk) is True + + +@pytest.mark.asyncio +async def test_append_outbound_attachments_maps_each_type(): + box_service = SimpleNamespace( + available=True, + collect_outbound_attachments=AsyncMock( + return_value=[ + {'type': 'Image', 'base64': 'data:image/png;base64,iVBORw0K'}, + {'type': 'Voice', 'base64': 'data:audio/wav;base64,UklGRg=='}, + {'type': 'File', 'name': 'report.xlsx', 'base64': 'data:app;base64,UEsDBA=='}, + ] + ), + ) + wrapper = _make_wrapper(box_service) + wrapper.ap.box_service = box_service + query = _make_query() + chain = platform_message.MessageChain([]) + + await wrapper._append_outbound_attachments(query, chain) + + kinds = [type(c).__name__ for c in chain] + assert kinds == ['Image', 'Voice', 'File'] + assert query.variables['_sandbox_outbound_collected'] is True + # File keeps its name + file_comp = chain[2] + assert getattr(file_comp, 'name', None) == 'report.xlsx' + + +@pytest.mark.asyncio +async def test_append_outbound_attachments_runs_once_per_query(): + box_service = SimpleNamespace( + available=True, + collect_outbound_attachments=AsyncMock(return_value=[]), + ) + wrapper = _make_wrapper(box_service) + wrapper.ap.box_service = box_service + query = _make_query() + query.variables['_sandbox_outbound_collected'] = True + chain = platform_message.MessageChain([]) + + await wrapper._append_outbound_attachments(query, chain) + + box_service.collect_outbound_attachments.assert_not_awaited() + assert len(chain) == 0 + + +@pytest.mark.asyncio +async def test_append_outbound_attachments_noop_without_box_service(): + wrapper = _make_wrapper(box_service=None) + wrapper.ap.box_service = None + query = _make_query() + chain = platform_message.MessageChain([]) + + await wrapper._append_outbound_attachments(query, chain) + assert len(chain) == 0 + # not marked collected, since service is unavailable + assert '_sandbox_outbound_collected' not in query.variables + + +@pytest.mark.asyncio +async def test_append_outbound_attachments_swallows_collection_error(): + box_service = SimpleNamespace( + available=True, + collect_outbound_attachments=AsyncMock(side_effect=RuntimeError('boom')), + ) + wrapper = _make_wrapper(box_service) + wrapper.ap.box_service = box_service + query = _make_query() + chain = platform_message.MessageChain([]) + + # must not raise + await wrapper._append_outbound_attachments(query, chain) + assert len(chain) == 0 + wrapper.ap.logger.warning.assert_called_once() diff --git a/tests/unit_tests/platform/test_websocket_adapter_attachments.py b/tests/unit_tests/platform/test_websocket_adapter_attachments.py new file mode 100644 index 000000000..18138383d --- /dev/null +++ b/tests/unit_tests/platform/test_websocket_adapter_attachments.py @@ -0,0 +1,92 @@ +"""Unit tests for WebSocketAdapter._process_image_components. + +The web debug client uploads Image / Voice / File components carrying a storage +key in ``path``. This helper resolves each to a base64 data URI (so multimodal +LLM input and the Box sandbox inbox have usable bytes), then deletes the +consumed storage object and clears ``path``. Covers mimetype selection per +type and graceful error handling. +""" + +from __future__ import annotations + +import base64 +from unittest.mock import AsyncMock, Mock + +import pytest + +from langbot.pkg.platform.sources.websocket_adapter import WebSocketAdapter + + +def _make_adapter(load_return=b'hello', load_side_effect=None): + provider = Mock() + provider.load = AsyncMock(return_value=load_return, side_effect=load_side_effect) + provider.delete = AsyncMock() + ap = Mock() + ap.storage_mgr.storage_provider = provider + logger = Mock() + logger.error = AsyncMock() + # WebSocketAdapter is a pydantic model; bypass full __init__/validation. + adapter = WebSocketAdapter.model_construct(ap=ap, logger=logger) + return adapter, provider + + +@pytest.mark.asyncio +async def test_image_jpeg_mimetype_and_cleanup(): + adapter, provider = _make_adapter(load_return=b'\xff\xd8\xff') + chain = [{'type': 'Image', 'path': 'storage://abc/photo.jpg'}] + + await adapter._process_image_components(chain) + + expected_b64 = base64.b64encode(b'\xff\xd8\xff').decode('utf-8') + assert chain[0]['base64'] == f'data:image/jpeg;base64,{expected_b64}' + assert chain[0]['path'] == '' # consumed + provider.delete.assert_awaited_once_with('storage://abc/photo.jpg') + + +@pytest.mark.asyncio +async def test_image_defaults_to_png(): + adapter, _ = _make_adapter() + chain = [{'type': 'Image', 'path': 'storage://abc/blob'}] + await adapter._process_image_components(chain) + assert chain[0]['base64'].startswith('data:image/png;base64,') + + +@pytest.mark.asyncio +async def test_voice_uses_guessed_or_wav_mimetype(): + adapter, _ = _make_adapter() + chain = [{'type': 'Voice', 'path': 'storage://abc/clip.wav'}] + await adapter._process_image_components(chain) + assert chain[0]['base64'].startswith('data:audio/') + + +@pytest.mark.asyncio +async def test_file_uses_octet_stream_fallback(): + adapter, _ = _make_adapter() + chain = [{'type': 'File', 'path': 'storage://abc/unknownblob'}] + await adapter._process_image_components(chain) + assert chain[0]['base64'].startswith('data:application/octet-stream;base64,') + + +@pytest.mark.asyncio +async def test_skips_components_without_path_or_unknown_type(): + adapter, provider = _make_adapter() + chain = [ + {'type': 'Image', 'path': ''}, # no path + {'type': 'Plain', 'path': 'storage://abc/x'}, # not a file component + {'type': 'At', 'target': '123'}, # no path key at all + ] + await adapter._process_image_components(chain) + provider.load.assert_not_awaited() + assert 'base64' not in chain[0] + assert 'base64' not in chain[1] + + +@pytest.mark.asyncio +async def test_load_failure_is_logged_not_raised(): + adapter, _ = _make_adapter(load_side_effect=RuntimeError('storage down')) + chain = [{'type': 'File', 'path': 'storage://abc/doc.pdf'}] + + # must not raise + await adapter._process_image_components(chain) + assert 'base64' not in chain[0] + adapter.logger.error.assert_awaited_once() diff --git a/tests/unit_tests/provider/test_litellm_convert_messages.py b/tests/unit_tests/provider/test_litellm_convert_messages.py new file mode 100644 index 000000000..87ad2e027 --- /dev/null +++ b/tests/unit_tests/provider/test_litellm_convert_messages.py @@ -0,0 +1,93 @@ +"""Unit tests for LiteLLMRequester._convert_messages. + +Focus: the content-part normalization that (a) converts image_base64 parts to +the OpenAI image_url shape and (b) drops non-image file parts (file_base64 / +file_url) which OpenAI-compatible chat models reject. The latter is essential +for Voice/File attachments — including ones replayed from conversation history — +since the agent consumes their bytes via the sandbox, not the model payload. +""" + +import langbot_plugin.api.entities.builtin.provider.message as provider_message + +from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester + + +def _make_requester() -> LiteLLMRequester: + # _convert_messages does not touch instance config, so bypass __init__. + return LiteLLMRequester.__new__(LiteLLMRequester) + + +def test_convert_messages_drops_file_base64_part(): + req = _make_requester() + msg = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('analyze this audio'), + provider_message.ContentElement.from_file_base64('data:audio/wav;base64,AAAA', 'voice.wav'), + ], + ) + out = req._convert_messages([msg]) + parts = out[0]['content'] + types = [p.get('type') for p in parts] + assert 'file_base64' not in types + assert types == ['text'] + assert parts[0]['text'] == 'analyze this audio' + + +def test_convert_messages_drops_file_url_part(): + req = _make_requester() + msg = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('here is a doc'), + provider_message.ContentElement.from_file_url('http://example.com/report.xlsx', 'report.xlsx'), + ], + ) + out = req._convert_messages([msg]) + types = [p.get('type') for p in out[0]['content']] + assert types == ['text'] + + +def test_convert_messages_keeps_image_and_converts_to_image_url(): + req = _make_requester() + msg = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('look'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,AAAA'), + ], + ) + out = req._convert_messages([msg]) + parts = out[0]['content'] + types = [p.get('type') for p in parts] + # image is preserved and reshaped to the OpenAI image_url form + assert types == ['text', 'image_url'] + img_part = parts[1] + assert img_part['image_url'] == {'url': 'data:image/png;base64,AAAA'} + assert 'image_base64' not in img_part + + +def test_convert_messages_mixed_history_strips_only_files(): + req = _make_requester() + # Simulate replayed history: an old voice turn + a current text turn. + history_voice = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('old audio turn'), + provider_message.ContentElement.from_file_base64('data:audio/wav;base64,BBBB', 'voice.wav'), + ], + ) + current = provider_message.Message( + role='user', + content=[provider_message.ContentElement.from_text('now do the csv')], + ) + out = req._convert_messages([history_voice, current]) + assert [p.get('type') for p in out[0]['content']] == ['text'] + assert [p.get('type') for p in out[1]['content']] == ['text'] + + +def test_convert_messages_plain_string_content_untouched(): + req = _make_requester() + msg = provider_message.Message(role='user', content='just text') + out = req._convert_messages([msg]) + assert out[0]['content'] == 'just text' diff --git a/tests/unit_tests/provider/test_localagent_inbound_attachments.py b/tests/unit_tests/provider/test_localagent_inbound_attachments.py new file mode 100644 index 000000000..bc7352f13 --- /dev/null +++ b/tests/unit_tests/provider/test_localagent_inbound_attachments.py @@ -0,0 +1,146 @@ +"""Unit tests for LocalAgentRunner._inject_inbound_attachments. + +Covers the user -> sandbox attachment path added for the Box attachment +round-trip: + +* materialized descriptors are stashed on the query and described to the model + via an appended text note (in-sandbox paths + outbox convention); +* non-image file parts (file_base64 / file_url) are stripped from the user + message content because OpenAI-compatible chat models reject them, while + image and text parts are kept for vision models; +* the helper is a no-op when the box service is unavailable or yields nothing, + and never raises into the chat turn on materialization failure. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +import langbot_plugin.api.entities.builtin.provider.message as provider_message + +from langbot.pkg.provider.runners.localagent import LocalAgentRunner + + +def _make_runner(box_service) -> LocalAgentRunner: + runner = LocalAgentRunner.__new__(LocalAgentRunner) + runner.ap = SimpleNamespace(logger=Mock(), box_service=box_service) + return runner + + +def _make_query(): + return SimpleNamespace(variables={}, query_id='q-123') + + +def _box_service(attachments): + svc = SimpleNamespace( + available=True, + OUTBOX_MOUNT_DIR='/outbox', + materialize_inbound_attachments=AsyncMock(return_value=attachments), + ) + return svc + + +@pytest.mark.asyncio +async def test_inject_strips_file_parts_and_appends_note(): + box = _box_service([{'type': 'Voice', 'path': '/inbox/q-123/voice.wav', 'size': 176000}]) + runner = _make_runner(box) + query = _make_query() + user_message = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('transcribe this'), + provider_message.ContentElement.from_file_base64('data:audio/wav;base64,AAAA', 'voice.wav'), + ], + ) + + await runner._inject_inbound_attachments(query, user_message) + + types = [getattr(ce, 'type', None) for ce in user_message.content] + # file_base64 dropped; text kept; sandbox-path note appended as text + assert 'file_base64' not in types + assert types.count('text') == 2 + note = user_message.content[-1].text + assert '/inbox/q-123/voice.wav' in note + assert '/outbox/q-123' in note + # descriptors stashed for downstream stages + assert query.variables['_sandbox_inbound_attachments'] == box.materialize_inbound_attachments.return_value + + +@pytest.mark.asyncio +async def test_inject_keeps_image_parts(): + box = _box_service([{'type': 'Image', 'path': '/inbox/q-123/pic.png', 'size': 1234}]) + runner = _make_runner(box) + query = _make_query() + user_message = provider_message.Message( + role='user', + content=[ + provider_message.ContentElement.from_text('what is this'), + provider_message.ContentElement.from_image_base64('data:image/png;base64,iVBORw0K'), + ], + ) + + await runner._inject_inbound_attachments(query, user_message) + + types = [getattr(ce, 'type', None) for ce in user_message.content] + assert 'image_base64' in types # vision part preserved + assert types[-1] == 'text' # note appended last + + +@pytest.mark.asyncio +async def test_inject_promotes_string_content_to_list_with_note(): + box = _box_service([{'type': 'File', 'path': '/inbox/q-123/data.csv', 'size': 42}]) + runner = _make_runner(box) + query = _make_query() + user_message = provider_message.Message(role='user', content='clean this csv') + + await runner._inject_inbound_attachments(query, user_message) + + assert isinstance(user_message.content, list) + assert [getattr(ce, 'type', None) for ce in user_message.content] == ['text', 'text'] + assert user_message.content[0].text == 'clean this csv' + assert '/inbox/q-123/data.csv' in user_message.content[1].text + + +@pytest.mark.asyncio +async def test_inject_noop_without_box_service(): + runner = _make_runner(box_service=None) + query = _make_query() + user_message = provider_message.Message(role='user', content='hello') + + await runner._inject_inbound_attachments(query, user_message) + + assert user_message.content == 'hello' + assert '_sandbox_inbound_attachments' not in query.variables + + +@pytest.mark.asyncio +async def test_inject_noop_when_no_attachments(): + box = _box_service([]) + runner = _make_runner(box) + query = _make_query() + user_message = provider_message.Message(role='user', content='hello') + + await runner._inject_inbound_attachments(query, user_message) + + assert user_message.content == 'hello' + assert '_sandbox_inbound_attachments' not in query.variables + + +@pytest.mark.asyncio +async def test_inject_swallows_materialization_error(): + box = SimpleNamespace( + available=True, + OUTBOX_MOUNT_DIR='/outbox', + materialize_inbound_attachments=AsyncMock(side_effect=RuntimeError('disk full')), + ) + runner = _make_runner(box) + query = _make_query() + user_message = provider_message.Message(role='user', content='hello') + + # must not raise + await runner._inject_inbound_attachments(query, user_message) + assert user_message.content == 'hello' + runner.ap.logger.warning.assert_called_once() diff --git a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx index 318dcc7b9..b45e87dd1 100644 --- a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx +++ b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx @@ -15,6 +15,7 @@ import { At, Quote, Voice, + File as FileComponent, Source, } from '@/app/infra/entities/message'; import { toast } from 'sonner'; @@ -64,7 +65,12 @@ export default function DebugDialog({ const [isHovering, setIsHovering] = useState(false); const [isConnected, setIsConnected] = useState(false); const [selectedImages, setSelectedImages] = useState< - Array<{ file: File; preview: string; fileKey?: string }> + Array<{ + file: File; + preview: string; + fileKey?: string; + kind: 'image' | 'voice' | 'file'; + }> >([]); const [isUploading, setIsUploading] = useState(false); const [previewImageUrl, setPreviewImageUrl] = useState(''); @@ -292,23 +298,38 @@ export default function DebugDialog({ const files = e.target.files; if (!files || files.length === 0) return; - const newImages: Array<{ file: File; preview: string }> = []; + const newImages: Array<{ + file: File; + preview: string; + kind: 'image' | 'voice' | 'file'; + }> = []; for (let i = 0; i < files.length; i++) { const file = files[i]; if (file.type.startsWith('image/')) { - const preview = URL.createObjectURL(file); - newImages.push({ file, preview }); + newImages.push({ + file, + preview: URL.createObjectURL(file), + kind: 'image', + }); + } else if (file.type.startsWith('audio/')) { + newImages.push({ file, preview: '', kind: 'voice' }); + } else { + newImages.push({ file, preview: '', kind: 'file' }); } } setSelectedImages((prev) => [...prev, ...newImages]); + // reset the input so selecting the same file again re-triggers onChange + e.target.value = ''; }; const handleRemoveImage = (index: number) => { setSelectedImages((prev) => { const newImages = [...prev]; - URL.revokeObjectURL(newImages[index].preview); + if (newImages[index].preview) { + URL.revokeObjectURL(newImages[index].preview); + } newImages.splice(index, 1); return newImages; }); @@ -372,19 +393,33 @@ export default function DebugDialog({ }); } - // Upload images and add to message chain - for (const image of selectedImages) { + // Upload attachments and add to message chain + for (const attachment of selectedImages) { try { - const result = await httpClient.uploadWebSocketImage( - selectedPipelineId, - image.file, - ); - messageChain.push({ - type: 'Image', - path: result.file_key, - }); + if (attachment.kind === 'image') { + const result = await httpClient.uploadWebSocketImage( + selectedPipelineId, + attachment.file, + ); + messageChain.push({ + type: 'Image', + path: result.file_key, + }); + } else { + // Voice / File go through the generic document upload endpoint, + // which returns a storage key the backend resolves into the + // sandbox inbox just like images. + const result = await httpClient.uploadDocumentFile(attachment.file); + messageChain.push({ + type: attachment.kind === 'voice' ? 'Voice' : 'File', + path: result.file_id, + ...(attachment.kind === 'file' + ? { name: attachment.file.name } + : {}), + }); + } } catch (error) { - console.error('Image upload failed:', error); + console.error('Attachment upload failed:', error); toast.error(t('pipelines.debugDialog.imageUploadFailed')); } } @@ -393,7 +428,9 @@ export default function DebugDialog({ setInputValue(''); setHasAt(false); setQuotedMessage(null); - selectedImages.forEach((img) => URL.revokeObjectURL(img.preview)); + selectedImages.forEach((img) => { + if (img.preview) URL.revokeObjectURL(img.preview); + }); setSelectedImages([]); // Send message via WebSocket @@ -460,13 +497,29 @@ export default function DebugDialog({ } case 'File': { - const file = component as MessageChainComponent & { name?: string }; + const file = component as FileComponent; + const downloadHref = file.base64 + ? file.base64.startsWith('data:') + ? file.base64 + : `data:application/octet-stream;base64,${file.base64}` + : file.url || ''; + const fileName = file.name || 'Unknown'; return (
- - [{t('pipelines.debugDialog.file')}] {file.name || 'Unknown'} - + {downloadHref ? ( + + [{t('pipelines.debugDialog.file')}] {fileName} + + ) : ( + + [{t('pipelines.debugDialog.file')}] {fileName} + + )}
); } @@ -844,17 +897,30 @@ export default function DebugDialog({
)} - {/* Image preview area */} + {/* Attachment preview area */} {selectedImages.length > 0 && (
{selectedImages.map((image, index) => (
- {`preview-${index}`} + {image.kind === 'image' ? ( + {`preview-${index}`} + ) : ( +
+ {image.kind === 'voice' ? ( + + ) : ( + + )} + + {image.file.name} + +
+ )} +
{opt.has_input && selected && ( Date: Thu, 18 Jun 2026 14:06:04 +0000 Subject: [PATCH 16/16] Harden agent runner tool runtimes (#2247) * fix(tools): harden agent runner tool runtimes * fix(tools): bootstrap Python workspaces with available interpreter * fix(tools): clear stale Python workspace env locks * fix(tools): decouple runtime from agent runner * test(tools): cover runtime hardening edge cases * fix(tools): support binary workspace file chunks --- src/langbot/pkg/box/workspace.py | 23 +- .../provider/tools/loaders/availability.py | 18 + .../pkg/provider/tools/loaders/mcp_stdio.py | 97 ++- .../pkg/provider/tools/loaders/native.py | 603 +++++++++++++++--- .../pkg/provider/tools/loaders/skill.py | 39 ++ .../provider/tools/loaders/skill_authoring.py | 75 +-- tests/unit_tests/box/test_workspace.py | 4 +- .../provider/test_mcp_box_integration.py | 80 ++- tests/unit_tests/provider/test_skill_tools.py | 33 +- .../provider/test_tool_manager_native.py | 205 ++++++ 10 files changed, 1008 insertions(+), 169 deletions(-) create mode 100644 src/langbot/pkg/provider/tools/loaders/availability.py diff --git a/src/langbot/pkg/box/workspace.py b/src/langbot/pkg/box/workspace.py index 948622efb..26d1a41e9 100644 --- a/src/langbot/pkg/box/workspace.py +++ b/src/langbot/pkg/box/workspace.py @@ -146,13 +146,19 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace' _LB_PIP_CACHE_DIR="{mount_path}/.cache/pip" mkdir -p "$_LB_META_DIR" "$_LB_TMP_DIR" "$_LB_PIP_CACHE_DIR" + _LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)" + if [ -z "$_LB_SYSTEM_PYTHON" ]; then + echo "python3 or python is required to prepare the workspace Python environment" >&2 + exit 127 + fi + export TMPDIR="$_LB_TMP_DIR" export TEMP="$_LB_TMP_DIR" export TMP="$_LB_TMP_DIR" export PIP_CACHE_DIR="$_LB_PIP_CACHE_DIR" _lb_python_meta() {{ - python - <<'PY' + "$_LB_SYSTEM_PYTHON" - <<'PY' import hashlib import json import os @@ -201,15 +207,26 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace' _LB_LOCK_WAIT=0 while ! mkdir "$_LB_LOCK_DIR" 2>/dev/null; do if [ "$_LB_LOCK_WAIT" -ge 120 ]; then + _LB_LOCK_OWNER="$(cat "$_LB_LOCK_DIR/pid" 2>/dev/null || true)" + if [ -n "$_LB_LOCK_OWNER" ] && kill -0 "$_LB_LOCK_OWNER" 2>/dev/null; then + echo "Timed out waiting for active Python environment lock: $_LB_LOCK_DIR" >&2 + exit 1 + fi + echo "Timed out waiting for Python environment lock, clearing stale lock: $_LB_LOCK_DIR" >&2 + rm -rf "$_LB_LOCK_DIR" 2>/dev/null || true + if mkdir "$_LB_LOCK_DIR" 2>/dev/null; then + break + fi echo "Timed out waiting for Python environment lock: $_LB_LOCK_DIR" >&2 exit 1 fi sleep 1 _LB_LOCK_WAIT=$((_LB_LOCK_WAIT + 1)) done + printf '%s\\n' "$$" > "$_LB_LOCK_DIR/pid" 2>/dev/null || true _lb_cleanup_lock() {{ - rmdir "$_LB_LOCK_DIR" >/dev/null 2>&1 || true + rm -rf "$_LB_LOCK_DIR" >/dev/null 2>&1 || true }} trap _lb_cleanup_lock EXIT INT TERM @@ -225,7 +242,7 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace' if [ "$_LB_NEEDS_BOOTSTRAP" -eq 1 ]; then rm -rf "$_LB_VENV_DIR" - python -m venv "$_LB_VENV_DIR" + "$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR" . "$_LB_VENV_DIR/bin/activate" python -m pip install --upgrade pip setuptools wheel if [ -f "{mount_path}/requirements.txt" ]; then diff --git a/src/langbot/pkg/provider/tools/loaders/availability.py b/src/langbot/pkg/provider/tools/loaders/availability.py new file mode 100644 index 000000000..58d795864 --- /dev/null +++ b/src/langbot/pkg/provider/tools/loaders/availability.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any + + +async def is_box_backend_available(ap: Any) -> bool: + """Return whether the configured Box backend is ready for tool execution.""" + box_service = getattr(ap, 'box_service', None) + if box_service is None: + return False + if not getattr(box_service, 'available', False): + return False + try: + status = await box_service.get_status() + backend_info = status.get('backend', {}) + return bool(backend_info.get('available', False)) + except Exception: + return False diff --git a/src/langbot/pkg/provider/tools/loaders/mcp_stdio.py b/src/langbot/pkg/provider/tools/loaders/mcp_stdio.py index ff607e661..dcfbb9132 100644 --- a/src/langbot/pkg/provider/tools/loaders/mcp_stdio.py +++ b/src/langbot/pkg/provider/tools/loaders/mcp_stdio.py @@ -5,6 +5,8 @@ import asyncio import os import shutil import shlex +import threading +from contextlib import suppress from typing import TYPE_CHECKING, Any import pydantic @@ -18,12 +20,26 @@ from ....box.workspace import ( rewrite_mounted_path, rewrite_venv_command, unwrap_venv_path, + wrap_python_command_with_env, ) if TYPE_CHECKING: from .mcp import RuntimeMCPSession +_WORKSPACE_COPY_LOCKS: dict[str, threading.Lock] = {} +_WORKSPACE_COPY_LOCKS_GUARD = threading.Lock() + + +def _workspace_copy_lock(path: str) -> threading.Lock: + with _WORKSPACE_COPY_LOCKS_GUARD: + lock = _WORKSPACE_COPY_LOCKS.get(path) + if lock is None: + lock = threading.Lock() + _WORKSPACE_COPY_LOCKS[path] = lock + return lock + + class MCPSessionErrorPhase(enum.Enum): """Which phase of the MCP lifecycle failed.""" @@ -49,7 +65,7 @@ class MCPServerBoxConfig(pydantic.BaseModel): host_path: str | None = None host_path_mode: str = 'ro' # MCP servers default to read-write mount only when explicitly requested env: dict[str, str] = pydantic.Field(default_factory=dict) - startup_timeout_sec: int = 120 # Longer default to allow dependency bootstrap + startup_timeout_sec: int = 300 # First Docker bootstrap may need to build a venv and install MCP deps. cpus: float | None = None memory_mb: int | None = None pids_limit: int | None = None @@ -128,6 +144,7 @@ class BoxStdioSessionRuntime: workspace = self._build_workspace(host_path=None) host_path = self.resolve_host_path() process_cwd = '/workspace' + install_cmd: str | None = None try: await workspace.create_session() @@ -168,6 +185,8 @@ class BoxStdioSessionRuntime: env=self.server_config.get('env', {}), cwd=process_cwd, ) + if install_cmd: + payload = self._wrap_process_payload_with_python_env(payload, process_cwd) payload['process_id'] = self.process_id await workspace.box_service.start_managed_process(workspace.session_id, payload) except Exception: @@ -253,14 +272,42 @@ class BoxStdioSessionRuntime: @staticmethod def _copy_workspace_tree(source_path: str, process_host_root: str, process_host_workspace: str) -> None: - shutil.rmtree(process_host_root, ignore_errors=True) - os.makedirs(process_host_root, exist_ok=True) - shutil.copytree( - source_path, - process_host_workspace, - symlinks=True, - ignore=shutil.ignore_patterns('.git', '__pycache__', '.pytest_cache', '.mypy_cache', '.ruff_cache'), - ) + # Docker-backed bootstrap writes root-owned runtime directories such as + # .venv/.tmp into the staged workspace. The host process may not be able + # to delete them, so refresh source files in place and preserve runtime + # directories instead of rmtree'ing the whole staging root. + with _workspace_copy_lock(process_host_root): + preserved_names = {'.venv', 'venv', 'env', '.cache', '.tmp', '.langbot'} + os.makedirs(process_host_workspace, exist_ok=True) + for name in os.listdir(process_host_workspace): + if name in preserved_names: + continue + path = os.path.join(process_host_workspace, name) + if os.path.isdir(path) and not os.path.islink(path): + shutil.rmtree(path, ignore_errors=True) + else: + # The entry may disappear between listdir and unlink if cleanup races us. + with suppress(FileNotFoundError): + os.unlink(path) + shutil.copytree( + source_path, + process_host_workspace, + symlinks=True, + dirs_exist_ok=True, + ignore=shutil.ignore_patterns( + '.git', + '__pycache__', + '.pytest_cache', + '.mypy_cache', + '.ruff_cache', + '.venv', + 'venv', + 'env', + '.cache', + '.tmp', + '.langbot', + ), + ) async def _cleanup_staged_workspace(self) -> None: if not self.resolve_host_path(): @@ -343,23 +390,25 @@ class BoxStdioSessionRuntime: @staticmethod def detect_install_command(host_path: str, workspace_path: str = '/workspace') -> str | None: workspace_kind = classify_python_workspace(host_path) - quoted_workspace_path = shlex.quote(workspace_path) - if workspace_kind == 'package': - return ( - 'mkdir -p /opt/_lb_src' - f' && tar -C {quoted_workspace_path}' - ' --exclude=.venv --exclude=.git --exclude=__pycache__' - ' --exclude=node_modules --exclude=.tox --exclude=.nox' - ' --exclude="*.egg-info" --exclude=.uv-cache' - ' -cf - .' - ' | tar -C /opt/_lb_src -xf -' - ' && pip install --no-cache-dir /opt/_lb_src' - ' && rm -rf /opt/_lb_src' - ) - if workspace_kind == 'requirements': - return f'pip install --no-cache-dir -r {quoted_workspace_path}/requirements.txt' + if workspace_kind in {'package', 'requirements'}: + return wrap_python_command_with_env('python -c "pass"', mount_path=workspace_path).rstrip() return None + @staticmethod + def _wrap_process_payload_with_python_env(payload: dict[str, Any], workspace_path: str) -> dict[str, Any]: + """Start a prepared Python workspace without writing bootstrap output to MCP stdio.""" + workspace_root = workspace_path.rstrip('/') or '/workspace' + venv_dir = f'{workspace_root}/.venv' + venv_bin = f'{venv_dir}/bin' + command = ' '.join([shlex.quote(payload['command']), *[shlex.quote(arg) for arg in payload.get('args', [])]]) + wrapped = dict(payload) + wrapped['command'] = 'sh' + wrapped['args'] = [ + '-lc', + (f'export VIRTUAL_ENV={shlex.quote(venv_dir)}; export PATH={shlex.quote(venv_bin)}:$PATH; exec {command}'), + ] + return wrapped + def build_box_session_payload(self, session_id: str, host_path: str | None = None) -> dict[str, Any]: workspace = self._build_workspace() workspace.session_id = session_id diff --git a/src/langbot/pkg/provider/tools/loaders/native.py b/src/langbot/pkg/provider/tools/loaders/native.py index 833900491..7f5ee4226 100644 --- a/src/langbot/pkg/provider/tools/loaders/native.py +++ b/src/langbot/pkg/provider/tools/loaders/native.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import json import os @@ -8,6 +9,7 @@ from langbot_plugin.api.entities.events import pipeline_query from .. import loader from ..errors import ToolNotFoundError +from .availability import is_box_backend_available from . import skill as skill_loader EXEC_TOOL_NAME = 'exec' @@ -22,6 +24,15 @@ _ALL_TOOL_NAMES = {EXEC_TOOL_NAME, READ_TOOL_NAME, WRITE_TOOL_NAME, EDIT_TOOL_NA # Skip these dirs during grep walk to avoid noise _SKIP_DIRS = {'.git', 'node_modules', '__pycache__', '.venv', 'venv', '.tox', 'dist', 'build'} +_DEFAULT_READ_MAX_LINES = 2000 +_MAX_READ_MAX_LINES = 10000 +_DEFAULT_TOOL_RESULT_MAX_BYTES = 50 * 1024 +_BOX_FILE_SCRIPT_MAX_BYTES = 2048 +_GLOB_MAX_MATCHES = 100 +_GREP_MAX_MATCHES = 200 +_GREP_MAX_FILES = 5000 +_GREP_MAX_LINE_CHARS = 500 + class NativeToolLoader(loader.ToolLoader): def __init__(self, ap): @@ -43,18 +54,7 @@ class NativeToolLoader(loader.ToolLoader): async def _check_backend_available(self) -> bool: """Check if the box backend is truly available (not just the runtime).""" - box_service = getattr(self.ap, 'box_service', None) - if box_service is None: - return False - if not getattr(box_service, 'available', False): - return False - # Check if backend is truly available via get_status - try: - status = await box_service.get_status() - backend_info = status.get('backend', {}) - return backend_info.get('available', False) - except Exception: - return False + return await is_box_backend_available(self.ap) async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]: if not self._is_sandbox_available(): @@ -139,6 +139,7 @@ class NativeToolLoader(loader.ToolLoader): # via execute_tool. Skills are mounted at /workspace/.skills/{name}/ # via extra_mounts built by BoxService. result = await self.ap.box_service.execute_tool(parameters, query) + result = self._normalize_exec_result(result) if selected_skill is not None: self._refresh_skill_from_disk(selected_skill) @@ -227,34 +228,121 @@ class NativeToolLoader(loader.ToolLoader): except Exception: return {'ok': False, 'error': stdout or 'Box file operation returned no result'} - async def _read_workspace_via_box(self, path: str, query: pipeline_query.Query) -> dict: + async def _read_workspace_via_box(self, path: str, parameters: dict, query: pipeline_query.Query) -> dict: + offset = self._positive_int(parameters.get('offset'), default=1) + byte_offset = self._non_negative_int(parameters.get('byte_offset'), default=0) + max_lines = self._positive_int( + parameters.get('limit'), + default=_DEFAULT_READ_MAX_LINES, + max_value=_MAX_READ_MAX_LINES, + ) + # Box file fallback returns through exec stdout, which is already capped + # by BoxService. Keep this payload small enough to remain valid JSON. + max_bytes = min( + self._positive_int(parameters.get('max_bytes'), default=_DEFAULT_TOOL_RESULT_MAX_BYTES), + _BOX_FILE_SCRIPT_MAX_BYTES, + ) + encoding = self._read_encoding(parameters) script = f""" -import json, os +import base64, json, os path = {json.dumps(path)} +offset = {offset} +byte_offset = {byte_offset} +max_lines = {max_lines} +max_bytes = {max_bytes} +encoding = {json.dumps(encoding)} if not path.startswith('/workspace'): print(json.dumps({{'ok': False, 'error': 'Path must be under /workspace.'}})) elif not os.path.exists(path): print(json.dumps({{'ok': False, 'error': f'File not found: {{path}}'}})) elif os.path.isdir(path): - print(json.dumps({{'ok': True, 'content': '\\n'.join(sorted(os.listdir(path))), 'is_directory': True}})) + entries = sorted(os.listdir(path)) + content = '\\n'.join(entries) + print(json.dumps({{'ok': True, 'content': content, 'is_directory': True, 'total': len(entries), 'truncated': False}})) +elif encoding == 'base64': + size_bytes = os.path.getsize(path) + with open(path, 'rb') as f: + f.seek(byte_offset) + data = f.read(max_bytes + 1) + chunk = data[:max_bytes] + has_more = len(data) > max_bytes + print(json.dumps({{ + 'ok': True, + 'content': base64.b64encode(chunk).decode('ascii'), + 'encoding': 'base64', + 'byte_offset': byte_offset, + 'length': len(chunk), + 'size_bytes': size_bytes, + 'has_more': has_more, + 'next_byte_offset': byte_offset + len(chunk) if has_more else None, + 'max_bytes': max_bytes, + }})) else: + lines = [] + output_bytes = 0 + end_line = offset - 1 + truncated = False + next_offset = None with open(path, 'r', encoding='utf-8', errors='replace') as f: - print(json.dumps({{'ok': True, 'content': f.read()}})) + for line_number, line in enumerate(f, 1): + if line_number < offset: + continue + if len(lines) >= max_lines: + truncated = True + next_offset = line_number + break + line_bytes = len(line.encode('utf-8')) + if output_bytes + line_bytes > max_bytes: + truncated = True + next_offset = line_number + break + lines.append(line.rstrip('\\n')) + output_bytes += line_bytes + end_line = line_number + print(json.dumps({{ + 'ok': True, + 'content': '\\n'.join(lines), + 'truncated': truncated, + 'start_line': offset, + 'end_line': end_line, + 'next_offset': next_offset, + 'max_lines': max_lines, + 'max_bytes': max_bytes, + }})) """.strip() return await self._run_workspace_file_script(script, query) - async def _write_workspace_via_box(self, path: str, content: str, query: pipeline_query.Query) -> dict: + async def _write_workspace_via_box( + self, + path: str, + content: str, + parameters: dict, + query: pipeline_query.Query, + ) -> dict: + encoding, mode = self._write_options(parameters) script = f""" -import json, os +import base64, json, os path = {json.dumps(path)} content = {json.dumps(content)} +encoding = {json.dumps(encoding)} +mode = {json.dumps(mode)} if not path.startswith('/workspace'): print(json.dumps({{'ok': False, 'error': 'Path must be under /workspace.'}})) else: os.makedirs(os.path.dirname(path) or '/workspace', exist_ok=True) - with open(path, 'w', encoding='utf-8') as f: - f.write(content) - print(json.dumps({{'ok': True, 'path': path}})) + if encoding == 'base64': + try: + data = base64.b64decode(content, validate=True) + except Exception as exc: + print(json.dumps({{'ok': False, 'error': f'invalid base64 content: {{exc}}'}})) + else: + with open(path, 'ab' if mode == 'append' else 'wb') as f: + f.write(data) + print(json.dumps({{'ok': True, 'path': path}})) + else: + with open(path, 'a' if mode == 'append' else 'w', encoding='utf-8') as f: + f.write(content) + print(json.dumps({{'ok': True, 'path': path}})) """.strip() return await self._run_workspace_file_script(script, query) @@ -307,12 +395,27 @@ else: if not any(part in skip_dirs for part in item.parts) ] hits.sort(key=lambda item: item.stat().st_mtime if item.exists() else 0, reverse=True) - shown = hits[:100] + shown = hits[:{_GLOB_MAX_MATCHES}] matches = [] + output_bytes = 0 + truncated_by_bytes = False for item in shown: rel = os.path.relpath(str(item), path) - matches.append(os.path.join(path, rel).replace(os.sep, '/')) - print(json.dumps({{'ok': True, 'matches': matches, 'total': len(hits), 'truncated': len(hits) > 100}})) + sandbox_path = os.path.join(path, rel).replace(os.sep, '/') + entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if matches else 0) + if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}: + truncated_by_bytes = True + break + matches.append(sandbox_path) + output_bytes += entry_bytes + print(json.dumps({{ + 'ok': True, + 'matches': matches, + 'preview': '\\n'.join(matches), + 'total': len(hits), + 'truncated': len(hits) > len(matches) or truncated_by_bytes, + 'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if len(hits) > len(matches) else None), + }})) """.strip() return await self._run_workspace_file_script(script, query) @@ -350,29 +453,54 @@ else: continue if item.is_file(): files.append(item) - if len(files) >= 5000: + if len(files) >= {_GREP_MAX_FILES}: break matches = [] + output_bytes = 0 + truncated_by = None for fp in files: try: - text = fp.read_text(errors='ignore') + handle = fp.open('r', encoding='utf-8', errors='ignore') except OSError: continue - for lineno, line in enumerate(text.splitlines(), 1): - if regex.search(line): - if base.is_file(): - file_path = path - else: - rel = os.path.relpath(str(fp), path) - file_path = os.path.join(path, rel).replace(os.sep, '/') - matches.append({{'file': file_path, 'line': lineno, 'content': line.rstrip()}}) - if len(matches) >= 200: - break - if len(matches) >= 200: + with handle: + for lineno, line in enumerate(handle, 1): + if regex.search(line): + if base.is_file(): + file_path = path + else: + rel = os.path.relpath(str(fp), path) + file_path = os.path.join(path, rel).replace(os.sep, '/') + content = line.rstrip() + line_truncated = False + if len(content) > {_GREP_MAX_LINE_CHARS}: + content = content[:{_GREP_MAX_LINE_CHARS}] + '... [truncated]' + line_truncated = True + entry = {{'file': file_path, 'line': lineno, 'content': content}} + entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1 + if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}: + truncated_by = 'bytes' + break + if line_truncated and truncated_by is None: + truncated_by = 'line' + matches.append(entry) + output_bytes += entry_bytes + if len(matches) >= {_GREP_MAX_MATCHES}: + truncated_by = truncated_by or 'matches' + break + if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}: + break + if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}: break - print(json.dumps({{'ok': True, 'matches': matches, 'total': len(matches), 'truncated': len(matches) >= 200}})) + print(json.dumps({{ + 'ok': True, + 'matches': matches, + 'total': len(matches), + 'truncated': truncated_by is not None, + 'truncated_by': truncated_by, + }})) """.strip() return await self._run_workspace_file_script(script, query) @@ -387,14 +515,20 @@ else: ) if skill_request is not None and hasattr(self.ap.box_service, 'read_skill_file'): selected_skill, relative = skill_request + host_path = self._resolve_skill_host_path(selected_skill, relative) + if host_path and os.path.exists(host_path): + if os.path.isdir(host_path): + return self._build_directory_result(os.listdir(host_path)) + return self._read_text_file_preview(host_path, parameters) + try: result = await self.ap.box_service.read_skill_file(selected_skill['name'], relative) - return {'ok': True, 'content': result.get('content', '')} + return self._build_read_result_from_text(str(result.get('content', '')), parameters) except Exception: try: result = await self.ap.box_service.list_skill_files(selected_skill['name'], relative) entries = [entry['name'] for entry in result.get('entries', [])] - return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True} + return self._build_directory_result(entries) except Exception as exc: return {'ok': False, 'error': str(exc)} @@ -405,20 +539,19 @@ else: include_activated=True, ) if self._should_use_box_workspace_files(selected_skill): - return await self._read_workspace_via_box(path, query) + return await self._read_workspace_via_box(path, parameters, query) if not os.path.exists(host_path): return {'ok': False, 'error': f'File not found: {path}'} if os.path.isdir(host_path): entries = os.listdir(host_path) - return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True} - with open(host_path, 'r', errors='replace') as f: - content = f.read() - return {'ok': True, 'content': content} + return self._build_directory_result(entries) + return self._read_text_file_preview(host_path, parameters) async def _invoke_write(self, parameters: dict, query: pipeline_query.Query) -> dict: path = parameters['path'] content = parameters['content'] self.ap.logger.info(f'write tool invoked: query_id={query.query_id} path={path} length={len(content)}') + encoding, _mode = self._write_options(parameters) skill_request = self._resolve_skill_relative_path( query, path, @@ -426,6 +559,8 @@ else: include_activated=True, ) if skill_request is not None and hasattr(self.ap.box_service, 'write_skill_file'): + if encoding != 'text': + return {'ok': False, 'error': 'base64 writes to skill packages are not supported.'} selected_skill, relative = skill_request await self.ap.box_service.write_skill_file(selected_skill['name'], relative, content) await self.ap.skill_mgr.reload_skills() @@ -438,10 +573,12 @@ else: include_activated=True, ) if self._should_use_box_workspace_files(selected_skill): - return await self._write_workspace_via_box(path, content, query) + return await self._write_workspace_via_box(path, content, parameters, query) os.makedirs(os.path.dirname(host_path), exist_ok=True) - with open(host_path, 'w', encoding='utf-8') as f: - f.write(content) + try: + self._write_host_file(host_path, content, parameters) + except ValueError as exc: + return {'ok': False, 'error': str(exc)} self._refresh_skill_from_disk(selected_skill) return {'ok': True, 'path': path} @@ -584,6 +721,40 @@ else: 'type': 'string', 'description': 'Absolute path to the file (must be under /workspace).', }, + 'offset': { + 'type': 'integer', + 'description': '1-indexed line number to start reading from. Defaults to 1.', + 'default': 1, + 'minimum': 1, + }, + 'limit': { + 'type': 'integer', + 'description': f'Maximum number of lines to return. Defaults to {_DEFAULT_READ_MAX_LINES}.', + 'default': _DEFAULT_READ_MAX_LINES, + 'minimum': 1, + 'maximum': _MAX_READ_MAX_LINES, + }, + 'max_bytes': { + 'type': 'integer', + 'description': ( + f'Maximum bytes of file content to return. Defaults to {_DEFAULT_TOOL_RESULT_MAX_BYTES}.' + ), + 'default': _DEFAULT_TOOL_RESULT_MAX_BYTES, + 'minimum': 1, + 'maximum': _DEFAULT_TOOL_RESULT_MAX_BYTES, + }, + 'encoding': { + 'type': 'string', + 'description': 'Return text by default, or base64 for binary byte-range reads.', + 'enum': ['text', 'base64'], + 'default': 'text', + }, + 'byte_offset': { + 'type': 'integer', + 'description': '0-indexed byte offset used when encoding is base64. Defaults to 0.', + 'default': 0, + 'minimum': 0, + }, }, 'required': ['path'], 'additionalProperties': False, @@ -609,7 +780,19 @@ else: }, 'content': { 'type': 'string', - 'description': 'Content to write to the file.', + 'description': 'Text content, or base64 content when encoding is base64.', + }, + 'encoding': { + 'type': 'string', + 'description': 'Write content as text by default, or decode it from base64 for binary files.', + 'enum': ['text', 'base64'], + 'default': 'text', + }, + 'mode': { + 'type': 'string', + 'description': 'Overwrite the file by default, or append to it.', + 'enum': ['overwrite', 'append'], + 'default': 'overwrite', }, }, 'required': ['path', 'content'], @@ -740,22 +923,30 @@ else: hits.sort(key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True) total = len(hits) - shown = hits[:100] + shown = hits[:_GLOB_MAX_MATCHES] # Convert back to sandbox paths sandbox_paths = [] + output_bytes = 0 + truncated_by_bytes = False for h in shown: rel = os.path.relpath(str(h), host_path) sandbox_path = os.path.join(path, rel) + entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if sandbox_paths else 0) + if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES: + truncated_by_bytes = True + break sandbox_paths.append(sandbox_path) + output_bytes += entry_bytes - result_lines = sandbox_paths - result = '\n'.join(result_lines) - - if total > 100: - result += f'\n... ({total} matches, showing first 100)' - - return {'ok': True, 'matches': result_lines, 'total': total, 'truncated': total > 100} + return { + 'ok': True, + 'matches': sandbox_paths, + 'preview': '\n'.join(sandbox_paths), + 'total': total, + 'truncated': total > len(sandbox_paths) or truncated_by_bytes, + 'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if total > len(sandbox_paths) else None), + } async def _invoke_grep(self, parameters: dict, query: pipeline_query.Query) -> dict: pattern = parameters['pattern'] @@ -791,32 +982,46 @@ else: files = self._grep_walk(base, include) matches = [] + output_bytes = 0 + truncated_by = None for fp in files: try: - text = fp.read_text(errors='ignore') + handle = fp.open('r', encoding='utf-8', errors='ignore') except OSError: continue - for lineno, line in enumerate(text.splitlines(), 1): - if regex.search(line): - rel = os.path.relpath(str(fp), host_path) - sandbox_path = os.path.join(path, rel) - matches.append( - { + with handle: + for lineno, line in enumerate(handle, 1): + if regex.search(line): + rel = os.path.relpath(str(fp), host_path) + sandbox_path = os.path.join(path, rel) + content, line_truncated = self._truncate_grep_line(line.rstrip()) + entry = { 'file': sandbox_path, 'line': lineno, - 'content': line.rstrip(), + 'content': content, } - ) - if len(matches) >= 200: - break - if len(matches) >= 200: + entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1 + if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES: + truncated_by = 'bytes' + break + if line_truncated and truncated_by is None: + truncated_by = 'line' + matches.append(entry) + output_bytes += entry_bytes + if len(matches) >= _GREP_MAX_MATCHES: + truncated_by = truncated_by or 'matches' + break + if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES: + break + if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES: break return { 'ok': True, 'matches': matches, 'total': len(matches), - 'truncated': len(matches) >= 200, + 'truncated': truncated_by is not None, + 'truncated_by': truncated_by, } @staticmethod @@ -828,10 +1033,266 @@ else: continue if item.is_file(): results.append(item) - if len(results) >= 5000: + if len(results) >= _GREP_MAX_FILES: break return results + @staticmethod + def _resolve_skill_host_path(selected_skill: dict, relative: str) -> str | None: + package_root = str(selected_skill.get('package_root', '') or '').strip() + if not package_root: + return None + + host_root = os.path.realpath(package_root) + host_path = os.path.realpath(os.path.join(host_root, relative)) + if not (host_path == host_root or host_path.startswith(host_root + os.sep)): + raise ValueError('Path escapes the skill package boundary.') + return host_path + + def _normalize_exec_result(self, result: dict) -> dict: + normalized = dict(result) + stdout = str(normalized.get('stdout') or '') + stderr = str(normalized.get('stderr') or '') + stdout, stdout_capped = self._truncate_text_to_bytes_with_flag(stdout, _DEFAULT_TOOL_RESULT_MAX_BYTES) + stderr, stderr_capped = self._truncate_text_to_bytes_with_flag(stderr, _DEFAULT_TOOL_RESULT_MAX_BYTES) + normalized['stdout'] = stdout + normalized['stderr'] = stderr + normalized['stdout_truncated'] = bool(normalized.get('stdout_truncated') or stdout_capped) + normalized['stderr_truncated'] = bool(normalized.get('stderr_truncated') or stderr_capped) + + if stdout and stderr: + preview_raw = f'stdout:\n{stdout}\n\nstderr:\n{stderr}' + else: + preview_raw = stdout or stderr + preview, preview_capped = self._truncate_text_to_bytes_with_flag(preview_raw, _DEFAULT_TOOL_RESULT_MAX_BYTES) + normalized['preview'] = preview + normalized['truncated'] = bool( + normalized['stdout_truncated'] or normalized['stderr_truncated'] or preview_capped + ) + if preview_capped and not normalized.get('truncated_by'): + normalized['truncated_by'] = 'bytes' + return normalized + + def _build_directory_result(self, entries: list[str]) -> dict: + sorted_entries = sorted(str(entry) for entry in entries) + content = '\n'.join(sorted_entries) + preview = self._truncate_text_to_bytes(content, _DEFAULT_TOOL_RESULT_MAX_BYTES) + truncated = preview != content + return { + 'ok': True, + 'content': preview, + 'is_directory': True, + 'total': len(sorted_entries), + 'truncated': truncated, + 'truncated_by': 'bytes' if truncated else None, + } + + def _read_text_file_preview(self, host_path: str, parameters: dict) -> dict: + if self._read_encoding(parameters) == 'base64': + return self._read_binary_file_chunk(host_path, parameters) + + offset = self._positive_int(parameters.get('offset'), default=1) + max_lines = self._positive_int( + parameters.get('limit'), + default=_DEFAULT_READ_MAX_LINES, + max_value=_MAX_READ_MAX_LINES, + ) + max_bytes = self._positive_int( + parameters.get('max_bytes'), + default=_DEFAULT_TOOL_RESULT_MAX_BYTES, + max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES, + ) + lines: list[str] = [] + output_bytes = 0 + end_line = offset - 1 + truncated = False + truncated_by: str | None = None + next_offset: int | None = None + + with open(host_path, 'r', encoding='utf-8', errors='replace') as f: + for line_number, line in enumerate(f, 1): + if line_number < offset: + continue + if len(lines) >= max_lines: + truncated = True + truncated_by = 'lines' + next_offset = line_number + break + + line_bytes = len(line.encode('utf-8')) + if output_bytes + line_bytes > max_bytes: + truncated = True + truncated_by = 'bytes' + next_offset = line_number + break + + lines.append(line.rstrip('\n')) + output_bytes += line_bytes + end_line = line_number + + if not lines and truncated_by == 'bytes': + content = ( + f'[Line {next_offset or offset} exceeds the {self._format_size(max_bytes)} read limit. ' + 'Use exec with a byte-range command for this line, or read a different offset.]' + ) + else: + content = '\n'.join(lines) + + return { + 'ok': True, + 'content': content, + 'truncated': truncated, + 'truncated_by': truncated_by, + 'start_line': offset, + 'end_line': end_line, + 'next_offset': next_offset, + 'max_lines': max_lines, + 'max_bytes': max_bytes, + } + + def _read_binary_file_chunk(self, host_path: str, parameters: dict) -> dict: + byte_offset = self._non_negative_int(parameters.get('byte_offset'), default=0) + max_bytes = self._positive_int( + parameters.get('max_bytes'), + default=_DEFAULT_TOOL_RESULT_MAX_BYTES, + max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES, + ) + size_bytes = os.path.getsize(host_path) + with open(host_path, 'rb') as f: + f.seek(byte_offset) + data = f.read(max_bytes + 1) + chunk = data[:max_bytes] + has_more = len(data) > max_bytes + return { + 'ok': True, + 'content': base64.b64encode(chunk).decode('ascii'), + 'encoding': 'base64', + 'byte_offset': byte_offset, + 'length': len(chunk), + 'size_bytes': size_bytes, + 'has_more': has_more, + 'next_byte_offset': byte_offset + len(chunk) if has_more else None, + 'max_bytes': max_bytes, + } + + def _write_host_file(self, host_path: str, content: str, parameters: dict) -> None: + encoding, mode = self._write_options(parameters) + if encoding == 'base64': + try: + data = base64.b64decode(content, validate=True) + except Exception as exc: + raise ValueError(f'invalid base64 content: {exc}') from exc + with open(host_path, 'ab' if mode == 'append' else 'wb') as f: + f.write(data) + return + with open(host_path, 'a' if mode == 'append' else 'w', encoding='utf-8') as f: + f.write(content) + + @staticmethod + def _read_encoding(parameters: dict) -> str: + return 'base64' if parameters.get('encoding') == 'base64' else 'text' + + @staticmethod + def _write_options(parameters: dict) -> tuple[str, str]: + encoding = 'base64' if parameters.get('encoding') == 'base64' else 'text' + mode = 'append' if parameters.get('mode') == 'append' else 'overwrite' + return encoding, mode + + def _build_read_result_from_text(self, content: str, parameters: dict) -> dict: + offset = self._positive_int(parameters.get('offset'), default=1) + max_lines = self._positive_int( + parameters.get('limit'), + default=_DEFAULT_READ_MAX_LINES, + max_value=_MAX_READ_MAX_LINES, + ) + max_bytes = self._positive_int( + parameters.get('max_bytes'), + default=_DEFAULT_TOOL_RESULT_MAX_BYTES, + max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES, + ) + all_lines = content.splitlines() + start_index = offset - 1 + if start_index >= len(all_lines) and all_lines: + return {'ok': False, 'error': f'Offset {offset} is beyond end of file ({len(all_lines)} lines total)'} + output_lines: list[str] = [] + output_bytes = 0 + truncated = False + truncated_by: str | None = None + next_offset: int | None = None + for index, line in enumerate(all_lines[start_index:], start_index + 1): + if len(output_lines) >= max_lines: + truncated = True + truncated_by = 'lines' + next_offset = index + break + line_bytes = len(line.encode('utf-8')) + (1 if output_lines else 0) + if output_bytes + line_bytes > max_bytes: + truncated = True + truncated_by = 'bytes' + next_offset = index + break + output_lines.append(line) + output_bytes += line_bytes + + end_line = offset + len(output_lines) - 1 + return { + 'ok': True, + 'content': '\n'.join(output_lines), + 'truncated': truncated, + 'truncated_by': truncated_by, + 'start_line': offset, + 'end_line': end_line, + 'next_offset': next_offset, + 'max_lines': max_lines, + 'max_bytes': max_bytes, + } + + @staticmethod + def _positive_int(value, *, default: int, max_value: int | None = None) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + if parsed <= 0: + parsed = default + if max_value is not None: + parsed = min(parsed, max_value) + return parsed + + @staticmethod + def _non_negative_int(value, *, default: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + return parsed if parsed >= 0 else default + + @staticmethod + def _truncate_grep_line(line: str) -> tuple[str, bool]: + if len(line) <= _GREP_MAX_LINE_CHARS: + return line, False + return f'{line[:_GREP_MAX_LINE_CHARS]}... [truncated]', True + + @staticmethod + def _truncate_text_to_bytes(text: str, max_bytes: int) -> str: + return NativeToolLoader._truncate_text_to_bytes_with_flag(text, max_bytes)[0] + + @staticmethod + def _truncate_text_to_bytes_with_flag(text: str, max_bytes: int) -> tuple[str, bool]: + data = text.encode('utf-8') + if len(data) <= max_bytes: + return text, False + truncated = data[:max_bytes] + while truncated and (truncated[-1] & 0xC0) == 0x80: + truncated = truncated[:-1] + return truncated.decode('utf-8', errors='ignore'), True + + @staticmethod + def _format_size(bytes_count: int) -> str: + if bytes_count < 1024: + return f'{bytes_count}B' + return f'{bytes_count / 1024:.1f}KB' + def _summarize_parameters(self, parameters: dict) -> dict: summary = dict(parameters) cmd = str(summary.get('command', '')).strip() diff --git a/src/langbot/pkg/provider/tools/loaders/skill.py b/src/langbot/pkg/provider/tools/loaders/skill.py index 9df94fd28..b62f3e7d5 100644 --- a/src/langbot/pkg/provider/tools/loaders/skill.py +++ b/src/langbot/pkg/provider/tools/loaders/skill.py @@ -72,6 +72,45 @@ def register_activated_skill(query: pipeline_query.Query, skill_data: dict) -> N activated[skill_name] = skill_data +def normalize_skill_names(value: typing.Any) -> list[str]: + """Return a de-duplicated list of non-empty skill names.""" + if not isinstance(value, list): + return [] + + names: list[str] = [] + for item in value: + skill_name = str(item or '').strip() + if skill_name and skill_name not in names: + names.append(skill_name) + return names + + +def get_activated_skill_names(query: pipeline_query.Query) -> list[str]: + """Return activated skill names for callers that own persistence policy.""" + return normalize_skill_names(list(get_activated_skills(query).keys())) + + +def restore_activated_skills( + ap: app.Application, + query: pipeline_query.Query, + skill_names: typing.Any, +) -> list[str]: + """Restore caller-provided activated skill names into Query variables. + + Persistence and state scope ownership belong to higher-level flows. This + helper only rebuilds current Query state from pipeline-visible skills, so + removed or unbound skills stay unavailable to native exec/write/edit. + """ + restored: list[str] = [] + for skill_name in normalize_skill_names(skill_names): + skill_data = get_visible_skill(ap, query, skill_name) + if skill_data is None: + continue + register_activated_skill(query, skill_data) + restored.append(skill_name) + return restored + + def parse_skill_mount_path(sandbox_path: str) -> tuple[str | None, str]: normalized_path = str(sandbox_path or '/workspace').strip() or '/workspace' if normalized_path == SKILL_MOUNT_PREFIX: diff --git a/src/langbot/pkg/provider/tools/loaders/skill_authoring.py b/src/langbot/pkg/provider/tools/loaders/skill_authoring.py index 9d0fe6e9a..d53721785 100644 --- a/src/langbot/pkg/provider/tools/loaders/skill_authoring.py +++ b/src/langbot/pkg/provider/tools/loaders/skill_authoring.py @@ -6,6 +6,7 @@ import typing import langbot_plugin.api.entities.builtin.resource.tool as resource_tool from .. import loader +from .availability import is_box_backend_available # Align with Claude Code's Skill tool design: # - activate: Activate a skill via Tool Call, returns SKILL.md content @@ -45,18 +46,7 @@ class SkillToolLoader(loader.ToolLoader): async def _check_sandbox_available(self) -> bool: """Check if the box backend is truly available (not just the runtime).""" - box_service = getattr(self.ap, 'box_service', None) - if box_service is None: - return False - if not getattr(box_service, 'available', False): - return False - # Check if backend is truly available via get_status - try: - status = await box_service.get_status() - backend_info = status.get('backend', {}) - return backend_info.get('available', False) - except Exception: - return False + return await is_box_backend_available(self.ap) async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]: if not self._is_available(): @@ -92,16 +82,15 @@ class SkillToolLoader(loader.ToolLoader): if not skill_name: raise ValueError('skill_name is required') - skill_mgr = self.ap.skill_mgr - skill_data = skill_mgr.get_skill_by_name(skill_name) + from . import skill as skill_loader + + skill_data = skill_loader.get_visible_skill(self.ap, query, skill_name) if skill_data is None: - visible_skills = getattr(skill_mgr, 'skills', {}) + visible_skills = skill_loader.get_visible_skills(self.ap, query) available_names = ', '.join(sorted(visible_skills.keys())) or 'none' raise ValueError(f'Skill "{skill_name}" not found. Available skills: {available_names}') # Register activated skill for sandbox mount path resolution - from . import skill as skill_loader - skill_loader.register_activated_skill(query, skill_data) # Return SKILL.md content as Tool Result (injects into context) @@ -127,6 +116,7 @@ class SkillToolLoader(loader.ToolLoader): 'activated': True, 'skill_name': skill_name, 'mount_path': mount_path, + 'activated_skill_names': skill_loader.get_activated_skill_names(query), 'content': result_content, } @@ -201,13 +191,13 @@ class SkillToolLoader(loader.ToolLoader): return resource_tool.LLMTool( name=ACTIVATE_SKILL_TOOL_NAME, human_desc='Activate a skill', - description=self._build_activate_tool_description(), + description='Activate a pipeline-visible skill by name and return its instructions as a tool result.', parameters={ 'type': 'object', 'properties': { 'skill_name': { 'type': 'string', - 'description': 'The skill name to activate (no arguments). E.g., "pdf" or "data-analysis"', + 'description': 'The skill name to activate.', }, }, 'required': ['skill_name'], @@ -255,50 +245,3 @@ class SkillToolLoader(loader.ToolLoader): }, func=lambda parameters: parameters, ) - - def _build_activate_tool_description(self) -> str: - """Build tool description with embedded available_skills list.""" - skill_mgr = getattr(self.ap, 'skill_mgr', None) - if skill_mgr is None: - return 'Activate a skill. No skills are currently available.' - - skills = getattr(skill_mgr, 'skills', {}) - if not skills: - return 'Activate a skill. No skills are currently available.' - - # Build section - available_skills_lines = [''] - for skill_name, skill_data in sorted(skills.items()): - description = skill_data.get('description', '') - available_skills_lines.append('') - available_skills_lines.append(f'{skill_name}') - available_skills_lines.append(f'{description}') - available_skills_lines.append('') - available_skills_lines.append('') - - available_skills_block = '\n'.join(available_skills_lines) - - return f"""Activate a skill within the main conversation. - - -When users ask you to perform tasks, check if any of the available skills -below can help complete the task more effectively. Skills provide specialized -capabilities and domain knowledge. - -How to use skills: -- Invoke skills using this tool with the skill name only (no arguments) -- When you invoke a skill, you will see -The skill is activated - -- The skill's instructions will be provided in the tool result -- Examples: - - skill_name: "pdf" - invoke the pdf skill - - skill_name: "data-analysis" - invoke the data-analysis skill - -Important: -- Only use skills listed in below -- Do not invoke a skill that is already running -- To create a new skill: prepare it in /workspace, then use register_skill tool - - -{available_skills_block}""" diff --git a/tests/unit_tests/box/test_workspace.py b/tests/unit_tests/box/test_workspace.py index 809347e56..e4620ad32 100644 --- a/tests/unit_tests/box/test_workspace.py +++ b/tests/unit_tests/box/test_workspace.py @@ -54,7 +54,9 @@ def test_classify_python_workspace_detects_package_and_requirements(): def test_wrap_python_command_with_env_contains_bootstrap_and_command(): command = wrap_python_command_with_env('python script.py') - assert 'python -m venv "$_LB_VENV_DIR"' in command + assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command + assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command + assert 'kill -0 "$_LB_LOCK_OWNER"' in command assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command assert command.rstrip().endswith('python script.py') diff --git a/tests/unit_tests/provider/test_mcp_box_integration.py b/tests/unit_tests/provider/test_mcp_box_integration.py index 3e3a7a4d5..74cd2487c 100644 --- a/tests/unit_tests/provider/test_mcp_box_integration.py +++ b/tests/unit_tests/provider/test_mcp_box_integration.py @@ -180,7 +180,7 @@ class TestMCPServerBoxConfig: assert cfg.host_path is None assert cfg.host_path_mode == 'ro' assert cfg.env == {} - assert cfg.startup_timeout_sec == 120 + assert cfg.startup_timeout_sec == 300 assert cfg.cpus is None assert cfg.memory_mb is None assert cfg.pids_limit is None @@ -494,6 +494,84 @@ class TestBuildBoxProcessPayload: assert payload['args'] == ['/opt/other/server.py', '--flag'] +# ── Python Workspace Preparation ──────────────────────────────────── + + +class TestPythonWorkspacePreparation: + def test_requirements_workspace_uses_venv_bootstrap(self, mcp_module, tmp_path): + host_path = tmp_path / 'mcp-source' + host_path.mkdir() + (host_path / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8') + + command = mcp_module.BoxStdioSessionRuntime.detect_install_command( + str(host_path), + '/workspace/.mcp/u1/workspace', + ) + + assert command is not None + assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command + assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command + assert 'python -m pip install -r "/workspace/.mcp/u1/workspace/requirements.txt"' in command + assert 'pip install --no-cache-dir -r' not in command + + def test_staging_refresh_removes_stale_source_files_but_preserves_runtime_dirs(self, mcp_module, tmp_path): + source = tmp_path / 'source' + source.mkdir() + (source / 'server.py').write_text('print("new")\n', encoding='utf-8') + (source / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8') + (source / '.env').write_text('TOKEN=new\n', encoding='utf-8') + + process_root = tmp_path / 'shared' / '.mcp' / 'u1' + workspace = process_root / 'workspace' + (workspace / '.venv' / 'bin').mkdir(parents=True) + (workspace / '.venv' / 'bin' / 'python').write_text('', encoding='utf-8') + (workspace / '.langbot').mkdir() + (workspace / '.langbot' / 'python-env.lock').mkdir() + (workspace / '.env').write_text('TOKEN=old\n', encoding='utf-8') + (workspace / 'server.py').write_text('print("old")\n', encoding='utf-8') + (workspace / 'removed.py').write_text('stale\n', encoding='utf-8') + (workspace / 'removed_dir').mkdir() + (workspace / 'removed_dir' / 'old.txt').write_text('stale\n', encoding='utf-8') + + mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace)) + + assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n' + assert (workspace / 'requirements.txt').read_text(encoding='utf-8') == 'mcp==1.26.0\n' + assert (workspace / '.env').read_text(encoding='utf-8') == 'TOKEN=new\n' + assert not (workspace / 'removed.py').exists() + assert not (workspace / 'removed_dir').exists() + assert (workspace / '.venv' / 'bin' / 'python').exists() + assert (workspace / '.langbot' / 'python-env.lock').is_dir() + + def test_staging_refresh_ignores_unlink_race(self, mcp_module, tmp_path, monkeypatch): + mcp_stdio_module = sys.modules['langbot.pkg.provider.tools.loaders.mcp_stdio'] + + source = tmp_path / 'source' + source.mkdir() + (source / 'server.py').write_text('print("new")\n', encoding='utf-8') + + process_root = tmp_path / 'shared' / '.mcp' / 'u1' + workspace = process_root / 'workspace' + workspace.mkdir(parents=True) + stale_file = workspace / 'removed.py' + stale_file.write_text('stale\n', encoding='utf-8') + + real_unlink = os.unlink + + def unlink_with_race(path): + if os.fspath(path) == str(stale_file): + real_unlink(path) + raise FileNotFoundError(path) + real_unlink(path) + + monkeypatch.setattr(mcp_stdio_module.os, 'unlink', unlink_with_race) + + mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace)) + + assert not stale_file.exists() + assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n' + + # ── get_runtime_info_dict ─────────────────────────────────────────── diff --git a/tests/unit_tests/provider/test_skill_tools.py b/tests/unit_tests/provider/test_skill_tools.py index 847480c10..9db7b945e 100644 --- a/tests/unit_tests/provider/test_skill_tools.py +++ b/tests/unit_tests/provider/test_skill_tools.py @@ -193,6 +193,29 @@ class TestSkillPathHelpers: assert list(result.keys()) == ['visible'] + def test_restore_activated_skills_uses_caller_provided_names_and_visibility(self): + from langbot.pkg.provider.tools.loaders.skill import ( + ACTIVATED_SKILLS_KEY, + PIPELINE_BOUND_SKILLS_KEY, + get_activated_skill_names, + restore_activated_skills, + ) + + ap = _make_ap() + ap.skill_mgr = SimpleNamespace( + skills={ + 'visible': _make_skill_data(name='visible'), + 'hidden': _make_skill_data(name='hidden'), + } + ) + query = SimpleNamespace(variables={PIPELINE_BOUND_SKILLS_KEY: ['visible']}) + + restored = restore_activated_skills(ap, query, ['visible', 'hidden', 'visible', '']) + + assert restored == ['visible'] + assert list(query.variables[ACTIVATED_SKILLS_KEY].keys()) == ['visible'] + assert get_activated_skill_names(query) == ['visible'] + def test_resolve_virtual_skill_path_allows_visible_skill_reads(self): from langbot.pkg.provider.tools.loaders.skill import ( PIPELINE_BOUND_SKILLS_KEY, @@ -245,7 +268,8 @@ class TestSkillPathHelpers: command = wrap_skill_command_with_python_env('python scripts/run.py') - assert 'python -m venv "$_LB_VENV_DIR"' in command + assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command + assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command assert command.rstrip().endswith('python scripts/run.py') @@ -281,6 +305,7 @@ class TestSkillToolLoader: assert result['activated'] is True assert result['skill_name'] == 'demo' assert result['mount_path'] == '/workspace/.skills/demo' + assert result['activated_skill_names'] == ['demo'] assert 'Step 1' in result['content'] assert set(query.variables[ACTIVATED_SKILLS_KEY].keys()) == {'demo'} @@ -456,7 +481,9 @@ class TestNativeToolLoaderSkillPaths: SimpleNamespace(query_id='q1', variables={PIPELINE_BOUND_SKILLS_KEY: ['demo']}), ) - assert result == {'ok': True, 'content': 'demo instructions'} + assert result['ok'] is True + assert result['content'] == 'demo instructions' + assert result['truncated'] is False @pytest.mark.asyncio async def test_exec_in_activated_skill_mount_rewrites_command_and_refreshes(self): @@ -485,7 +512,7 @@ class TestNativeToolLoaderSkillPaths: query, ) - assert result == {'ok': True} + assert result['ok'] is True tool_parameters = ap.box_service.execute_tool.await_args.args[0] assert tool_parameters['command'] == 'python /workspace/.skills/demo/scripts/run.py' assert tool_parameters['workdir'] == '/workspace/.skills/demo' diff --git a/tests/unit_tests/provider/test_tool_manager_native.py b/tests/unit_tests/provider/test_tool_manager_native.py index 117a20fd3..01e044e5a 100644 --- a/tests/unit_tests/provider/test_tool_manager_native.py +++ b/tests/unit_tests/provider/test_tool_manager_native.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import os import tempfile from types import SimpleNamespace @@ -189,6 +190,78 @@ async def test_write_creates_subdirectories(): assert f.read() == 'nested' +@pytest.mark.asyncio +async def test_read_binary_file_as_base64_chunk(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + with open(os.path.join(tmpdir, 'blob.bin'), 'wb') as f: + f.write(b'\x00\x01\x02\x03\x04') + + result = await loader.invoke_tool( + 'read', + { + 'path': '/workspace/blob.bin', + 'encoding': 'base64', + 'byte_offset': 1, + 'max_bytes': 2, + }, + _make_query(), + ) + + assert result['ok'] is True + assert result['content'] == base64.b64encode(b'\x01\x02').decode('ascii') + assert result['encoding'] == 'base64' + assert result['byte_offset'] == 1 + assert result['length'] == 2 + assert result['size_bytes'] == 5 + assert result['has_more'] is True + assert result['next_byte_offset'] == 3 + + +@pytest.mark.asyncio +async def test_write_base64_file_append(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + + first = base64.b64encode(b'\x00\x01').decode('ascii') + second = base64.b64encode(b'\x02\x03').decode('ascii') + await loader.invoke_tool( + 'write', + {'path': '/workspace/blob.bin', 'content': first, 'encoding': 'base64'}, + _make_query(), + ) + result = await loader.invoke_tool( + 'write', + { + 'path': '/workspace/blob.bin', + 'content': second, + 'encoding': 'base64', + 'mode': 'append', + }, + _make_query(), + ) + + assert result['ok'] is True + with open(os.path.join(tmpdir, 'blob.bin'), 'rb') as f: + assert f.read() == b'\x00\x01\x02\x03' + + +@pytest.mark.asyncio +async def test_write_base64_rejects_invalid_content(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + + result = await loader.invoke_tool( + 'write', + {'path': '/workspace/blob.bin', 'content': 'not base64!', 'encoding': 'base64'}, + _make_query(), + ) + + assert result['ok'] is False + assert 'invalid base64' in result['error'] + assert not os.path.exists(os.path.join(tmpdir, 'blob.bin')) + + @pytest.mark.asyncio async def test_edit_replaces_unique_string(): with tempfile.TemporaryDirectory() as tmpdir: @@ -248,3 +321,135 @@ async def test_path_escape_blocked(): with pytest.raises(ValueError, match='escapes'): await loader.invoke_tool('read', {'path': '/workspace/../../etc/passwd'}, _make_query()) + + +@pytest.mark.asyncio +async def test_box_availability_helper_handles_unavailable_and_errors(): + from langbot.pkg.provider.tools.loaders.availability import is_box_backend_available + + assert await is_box_backend_available(SimpleNamespace()) is False + assert await is_box_backend_available(SimpleNamespace(box_service=SimpleNamespace(available=False))) is False + + unavailable_backend = SimpleNamespace( + available=True, + get_status=AsyncMock(return_value={'backend': {'available': False}}), + ) + assert await is_box_backend_available(SimpleNamespace(box_service=unavailable_backend)) is False + + failing_backend = SimpleNamespace( + available=True, + get_status=AsyncMock(side_effect=RuntimeError('box unavailable')), + ) + assert await is_box_backend_available(SimpleNamespace(box_service=failing_backend)) is False + + +@pytest.mark.asyncio +async def test_read_file_supports_offset_limit_and_truncation_metadata(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + with open(os.path.join(tmpdir, 'lines.txt'), 'w', encoding='utf-8') as f: + f.write('one\ntwo\nthree\nfour\n') + + result = await loader.invoke_tool( + 'read', + {'path': '/workspace/lines.txt', 'offset': 2, 'limit': 2}, + _make_query(), + ) + + assert result == { + 'ok': True, + 'content': 'two\nthree', + 'truncated': True, + 'truncated_by': 'lines', + 'start_line': 2, + 'end_line': 3, + 'next_offset': 4, + 'max_lines': 2, + 'max_bytes': 50 * 1024, + } + + +@pytest.mark.asyncio +async def test_read_file_handles_line_larger_than_byte_limit(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + with open(os.path.join(tmpdir, 'long-line.txt'), 'w', encoding='utf-8') as f: + f.write('abcdef\n') + + result = await loader.invoke_tool( + 'read', + {'path': '/workspace/long-line.txt', 'max_bytes': 3}, + _make_query(), + ) + + assert result['ok'] is True + assert result['truncated'] is True + assert result['truncated_by'] == 'bytes' + assert result['next_offset'] == 1 + assert 'exceeds the 3B read limit' in result['content'] + + +@pytest.mark.asyncio +async def test_exec_result_is_capped_and_exposes_preview_metadata(): + with tempfile.TemporaryDirectory() as tmpdir: + box_service = SimpleNamespace( + available=True, + default_workspace=tmpdir, + execute_tool=AsyncMock( + return_value={ + 'ok': True, + 'stdout': 'a' * 60000, + 'stderr': 'b' * 60000, + 'exit_code': 0, + } + ), + ) + loader = NativeToolLoader(SimpleNamespace(box_service=box_service, logger=Mock())) + + result = await loader.invoke_tool('exec', {'command': 'python -V'}, _make_query()) + + assert result['ok'] is True + assert len(result['stdout'].encode('utf-8')) == 50 * 1024 + assert len(result['stderr'].encode('utf-8')) == 50 * 1024 + assert len(result['preview'].encode('utf-8')) == 50 * 1024 + assert result['stdout_truncated'] is True + assert result['stderr_truncated'] is True + assert result['truncated'] is True + assert result['truncated_by'] == 'bytes' + + +@pytest.mark.asyncio +async def test_glob_caps_match_count_and_returns_preview(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + for index in range(105): + with open(os.path.join(tmpdir, f'file-{index:03d}.txt'), 'w', encoding='utf-8') as f: + f.write(str(index)) + + result = await loader.invoke_tool('glob', {'path': '/workspace', 'pattern': '*.txt'}, _make_query()) + + assert result['ok'] is True + assert result['total'] == 105 + assert len(result['matches']) == 100 + assert result['preview'] == '\n'.join(result['matches']) + assert result['truncated'] is True + assert result['truncated_by'] == 'matches' + + +@pytest.mark.asyncio +async def test_grep_reports_invalid_regex_and_truncates_long_matching_lines(): + with tempfile.TemporaryDirectory() as tmpdir: + loader, _ = _make_loader_with_workspace(tmpdir) + with open(os.path.join(tmpdir, 'data.txt'), 'w', encoding='utf-8') as f: + f.write('needle ' + ('x' * 600) + '\n') + + invalid = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': '['}, _make_query()) + result = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': 'needle'}, _make_query()) + + assert invalid['ok'] is False + assert 'Invalid regex' in invalid['error'] + assert result['ok'] is True + assert result['truncated'] is True + assert result['truncated_by'] == 'line' + assert result['matches'][0]['file'] == '/workspace/data.txt' + assert result['matches'][0]['content'].endswith('... [truncated]')