Compare commits

...

14 Commits

Author SHA1 Message Date
RockChinQ 3b3b09331a test(box): cover inbound/outbound attachment helpers; fix ruff format
- ruff format localagent.py (CI ruff format --check was failing)
- add unit tests for ResponseWrapper outbound-attachment helpers (wrapper.py 78%->98%)
- add unit tests for LocalAgentRunner._inject_inbound_attachments
- add unit tests for WebSocketAdapter._process_image_components (0%->covered)

Lifts PR patch coverage from 68.97% to ~88% (>75% target).
2026-06-17 22:08:42 -04:00
RockChinQ 75e5af26d0 feat(box): support voice/file attachment round-trip end-to-end
Extends the bidirectional attachment transfer to audio and arbitrary files
through the real webchat UI, and fixes the model-payload errors that
non-image attachments triggered.

- platform(websocket_adapter): resolve Voice/File component storage keys to
  base64 (previously only Image), so audio/documents reach the sandbox inbox.
- web(debug-dialog): accept audio/* and any file in the uploader (was
  image-only), classify by mimetype, upload Voice/File via the documents
  endpoint, and render non-image staged attachments as a chip.
- provider(litellmchat): drop non-image file parts (file_base64 / file_url)
  when building the OpenAI/LiteLLM payload. These come from Voice/File
  attachments — including ones replayed from conversation history — and the
  agent reads their bytes from the sandbox, not the model. Without this the
  provider rejects the request: 'invalid content type=file_base64'.
- provider(localagent): also strip those parts from the current user message
  alongside the sandbox-path note (model-facing clarity; the requester is the
  real safety net for history).
- tests: cover the requester strip/keep behavior (file dropped, image kept and
  reshaped to image_url, mixed history, plain-string content).
2026-06-17 21:57:09 -04:00
RockChinQ 22c0a18bea feat(box): bidirectional attachment transfer for sandbox
Materialize inbound attachments into the sandbox workspace so agents can
process user-sent files, and collect agent-produced files from the outbox
to attach them back to the reply.

- box(service): add materialize_inbound_attachments / collect_outbound
  attachments. Prefer direct host-filesystem read/write on the bind-mounted
  workspace (no size limit), falling back to chunked exec only for
  non-shared backends (e2b/remote). Clear per-query inbox/outbox dirs at
  turn start to avoid query_id-reuse collisions.
- provider(localagent): inject inbound attachment descriptors into the
  sandbox and append a system note telling the agent the inbox/outbox paths.
- pipeline(wrapper): collect outbox files on the final stream chunk and
  append them as attachment components to the response chain.
- web(debug-dialog): render File components with a download link when
  base64/url is present; add base64/path fields to the File entity.
- tests: cover inbound/outbound, large-file transfer without truncation,
  and stale-dir clearing (86 passing).
2026-06-17 19:01:03 -04:00
huanghuoguoguo b3c6de2072 [codex] cover frontend CRUD smoke flows (#2253)
* test: cover frontend CRUD smoke flows

* test: add bot CRUD smoke coverage

* test: add bot/pipeline advanced flows and cross-resource tests

- Bot enable/disable toggle with state persistence
- Bot detail tab switching (Configuration, Logs, Sessions)
- Bot form dirty state and save button behavior
- Bot name validation error display
- Pipeline tab switching (Configuration, Dashboard)
- Pipeline form dirty state
- Pipeline name validation error display
- Cross-resource flow: create pipeline then bind to bot
- Empty states for bots, pipelines, knowledge bases, MCP servers
2026-06-16 21:34:17 +08:00
RockChinQ 4e45886647 style(web): show Models above API Integration in main sidebar footer 2026-06-16 06:04:59 -04:00
RockChinQ f592656680 refactor(web): unify settings panel layouts with shared toolbar/body
- Add PanelToolbar/PanelBody primitives so all four settings tabs share
  the same top-toolbar + scrollable-body rhythm under the unified header.
- API panel: drop the heavy gray shadowed TabsList; move the create
  action into the toolbar next to the tabs, lighten per-tab hints.
- Storage panel: reuse PanelToolbar for the generated-at/refresh bar.
- Account panel: wrap content in PanelBody for consistent padding.
- Models panel: keep the pinned LangBot Models (Space) card at the very
  top, above the add-custom-provider row (intentional pin), using
  PanelBody instead of a top toolbar.
2026-06-16 06:02:20 -04:00
RockChinQ e9db858dcc feat(web): unified header for settings dialog, shorter sidebar labels
- Add a shared section header (icon + title + description) with right
  padding so the dialog close X no longer overlaps panel content, and
  every tab now shares the same top layout for a consistent look.
- Shorten inner sidebar nav labels (Models/API/Storage/Account) via new
  settingsDialog.nav.* i18n keys across all 8 locales.
- Add common.apiIntegrationDescription and account.settingsDescription
  for the new header.
2026-06-16 05:50:44 -04:00
RockChinQ 2d6faf9d5e refactor(web): drop legacy ModelsDialog, use unified SettingsDialog everywhere
The model-selector in dynamic forms (pipeline / knowledge base settings)
still opened the old standalone ModelsDialog. Point it at the unified
SettingsDialog (section pinned to models) and delete the now-unused
ModelsDialog wrapper so only the new dialog remains.
2026-06-16 05:41:58 -04:00
RockChinQ d4699547e9 i18n(web): localize Bots/Pipelines sidebar titles for es/th/vi
es-ES pipelines, th-TH bots+pipelines and vi-VN pipelines were left in
English in the sidebar. Translate them: es Flujos, th บอท/ไปป์ไลน์,
vi Quy trình.
2026-06-16 05:27:10 -04:00
RockChinQ 716d7aca94 fix(web): fixed-height settings dialog, narrower sidebar
Pin the dialog to a fixed 80vh (cap 800px) so switching sections no
longer resizes it; panels scroll their own content internally. Override
the SidebarProvider wrapper's default h-svh with h-full so both columns
fill the dialog height. Narrow the inner settings sidebar to w-44.
2026-06-16 05:22:42 -04:00
RockChinQ b3c00fe6da fix(web): use fixed height for settings dialog instead of 80vh
Avoid the dialog stretching to fill tall viewports (large empty space).
Pin to 620px with max-h-[85vh] fallback and narrow width to 52rem.
2026-06-16 05:18:14 -04:00
RockChinQ f4a6edf7ec refactor(web): unify settings dialogs into single dialog with sidebar
Merge API integration, model settings, account settings and storage
analysis into one SettingsDialog with a shadcn inner sidebar for
section switching. Preserve existing ?action= query-param deep links
(showModelSettings / showAccountSettings / showApiIntegrationSettings /
showStorageAnalysis) by mapping each to a section. Extract reusable
panels and keep ModelsDialog as a thin wrapper for the dynamic-form
model picker.
2026-06-16 05:06:06 -04:00
huanghuoguoguo f390980d0a test: format test suite (#2252) 2026-06-16 11:22:29 +08:00
huanghuoguoguo 1ae5aacc00 test: add frontend smoke and backend e2e CI (#2251) 2026-06-16 11:09:55 +08:00
144 changed files with 6034 additions and 2797 deletions
+46
View File
@@ -0,0 +1,46 @@
name: Frontend Tests
on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
push:
branches:
- master
- develop
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
jobs:
playwright-smoke:
name: Playwright Smoke
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '25'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 8.9.2
- name: Install dependencies
working-directory: web
run: pnpm install --frozen-lockfile
- name: Install Playwright browsers
working-directory: web
run: pnpm exec playwright install --with-deps chromium
- name: Run Playwright smoke tests
working-directory: web
run: pnpm test:e2e
+1 -1
View File
@@ -29,7 +29,7 @@ jobs:
run: uv sync --dev
- name: Run ruff check
run: uv run ruff check src
run: uv run ruff check src/langbot/ tests/ --output-format=concise
- name: Run ruff format
run: uv run ruff format src --check
+62 -1
View File
@@ -84,6 +84,67 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
e2e:
name: E2E Startup Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run E2E startup tests
run: uv run pytest tests/e2e -q --tb=short
- name: E2E Test Summary
if: always()
run: |
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
box-integration:
name: Box Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Check Docker runtime
run: docker info
- name: Run Box integration tests
run: uv run pytest tests/integration_tests -q --tb=short
- name: Box Integration Test Summary
if: always()
run: |
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage:
name: Coverage Gate
runs-on: ubuntu-latest
@@ -129,4 +190,4 @@ jobs:
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
+422
View File
@@ -335,6 +335,428 @@ class BoxService:
return await self.execute_spec_payload(spec_payload, query)
# ── Attachment passthrough (inbound / outbound) ──────────────────
#
# IM/webchat attachments (images, voices, files) reach the LLM as
# multimodal content, but historically never landed on the sandbox
# filesystem, so the agent's exec/read/write tools could not operate on
# them. Conversely, files the agent produced inside the sandbox were
# never surfaced back to the user. These two helpers close both gaps:
#
# inbound : message_chain attachments -> /workspace/inbox/<query_id>/
# outbound : /workspace/outbox/<query_id>/ -> reply MessageChain
#
# Transfer prefers DIRECT HOST FILESYSTEM access to the bind-mounted
# workspace (default_workspace on the host maps to /workspace inside the
# container), which has no size limit. This covers the local docker /
# nsjail / stdio backends. For backends where the workspace is NOT visible
# on the LangBot host (E2B, an external remote runtime.endpoint), it falls
# back to a base64-through-exec round-trip. The exec channel can only move
# small files reliably — the docker backend passes the command as a single
# argv (ARG_MAX) and exec stdout is truncated by output_limit_chars — so
# the host path is strongly preferred and used whenever available.
INBOX_MOUNT_DIR = '/workspace/inbox'
OUTBOX_MOUNT_DIR = '/workspace/outbox'
INBOX_SUBDIR = 'inbox'
OUTBOX_SUBDIR = 'outbox'
# Hard cap on a single attachment. The HTTP upload endpoints already cap
# uploads at 10MiB; keep parity.
_ATTACHMENT_MAX_BYTES = 10 * _MIB
# Conservative cap for the exec FALLBACK path only (ARG_MAX / stdout
# truncation). The host-filesystem path has no such limit.
_EXEC_FALLBACK_MAX_BYTES = 256 * 1024
def _host_query_dir(self, subdir: str, query_id) -> str | None:
"""Host path for ``/workspace/<subdir>/<query_id>`` when LangBot can
access the bind-mounted workspace directly, else ``None``.
``default_workspace`` is the host directory bind-mounted to
``/workspace`` for the local docker/nsjail backends and shared
outright in stdio mode, so a file written there by LangBot is visible
to the sandbox (and vice-versa). It is ``None`` / not a local dir for
E2B and remote runtimes, where we must fall back to the exec channel.
"""
root = self.default_workspace
if not root or not os.path.isdir(root):
return None
return os.path.join(root, subdir, str(query_id))
@staticmethod
def _sanitize_attachment_name(name: str, fallback: str) -> str:
"""Reduce an arbitrary attachment name to a safe basename.
Strips directory separators and parent refs so a crafted file name
can never escape the inbox/outbox directory.
"""
base = os.path.basename(str(name or '').replace('\\', '/').strip())
base = base.lstrip('.') or ''
# Drop anything that is not a conservative filename charset.
cleaned = ''.join(c for c in base if c.isalnum() or c in ('.', '_', '-', ' ')).strip()
cleaned = cleaned.replace(' ', '_')
return cleaned or fallback
@staticmethod
async def _component_to_bytes(component) -> tuple[bytes, str] | None:
"""Best-effort extraction of (bytes, mime) from a platform component.
Handles base64, http(s) url and local path sources. Returns None when
no payload can be resolved.
"""
import base64 as _b64
b64 = getattr(component, 'base64', None)
if b64:
data = b64
mime = 'application/octet-stream'
if isinstance(data, str) and data.startswith('data:'):
split_index = data.find(';base64,')
if split_index != -1:
mime = data[5:split_index]
data = data[split_index + 8 :]
try:
return _b64.b64decode(data), mime
except Exception:
return None
url = getattr(component, 'url', None)
if url:
try:
import httpx
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.get(url)
resp.raise_for_status()
return resp.content, resp.headers.get('Content-Type', 'application/octet-stream')
except Exception:
return None
path = getattr(component, 'path', None)
if path:
try:
import aiofiles
async with aiofiles.open(path, 'rb') as f:
return await f.read(), 'application/octet-stream'
except Exception:
return None
return None
async def _write_files_into_sandbox(
self,
query: pipeline_query.Query,
subdir: str,
target_mount_dir: str,
files: list[tuple[str, bytes]],
) -> list[str]:
"""Write *files* (name, bytes) into the per-query directory.
Prefers a direct host-filesystem write to the bind-mounted workspace
(no size limit). Falls back to a base64-through-exec round-trip only
when the workspace is not visible on the LangBot host (E2B / remote).
Returns the list of in-sandbox paths actually written.
"""
if not files:
return []
host_dir = self._host_query_dir(subdir, query.query_id)
if host_dir is not None:
return await asyncio.to_thread(self._write_files_host, host_dir, target_mount_dir, files)
return await self._write_files_via_exec(query, target_mount_dir, files)
def _write_files_host(
self,
host_dir: str,
target_mount_dir: str,
files: list[tuple[str, bytes]],
) -> list[str]:
"""Write attachments straight onto the bind-mounted host directory.
Recreates the per-query directory from scratch so a reused query_id
(the webchat session uses small sequential ids) never inherits stale
files from an earlier turn.
"""
import shutil
shutil.rmtree(host_dir, ignore_errors=True)
os.makedirs(host_dir, exist_ok=True)
written: list[str] = []
for name, data in files:
with open(os.path.join(host_dir, name), 'wb') as fh:
fh.write(data)
written.append(f'{target_mount_dir}/{name}')
return written
async def _write_files_via_exec(
self,
query: pipeline_query.Query,
target_dir: str,
files: list[tuple[str, bytes]],
) -> list[str]:
"""Fallback: ship files into the sandbox over the exec channel.
Only used for backends without host-filesystem access (E2B / remote).
Each file is base64-decoded inside the sandbox. Files larger than the
conservative exec cap are skipped (ARG_MAX / stdout limits).
"""
import base64 as _b64
import json as _json
manifest = []
for name, data in files:
if len(data) > self._EXEC_FALLBACK_MAX_BYTES:
self.ap.logger.warning(
f'Attachment "{name}" ({len(data)} bytes) exceeds the exec-channel '
f'fallback limit ({self._EXEC_FALLBACK_MAX_BYTES} bytes); skipping. '
f'Configure a host-shared workspace to transfer large files.'
)
continue
manifest.append({'name': name, 'b64': _b64.b64encode(data).decode('ascii')})
if not manifest:
return []
manifest_b64 = _b64.b64encode(_json.dumps(manifest).encode('utf-8')).decode('ascii')
script = (
'import base64, json, os, shutil\n'
f'target = {target_dir!r}\n'
'shutil.rmtree(target, ignore_errors=True)\n'
'os.makedirs(target, exist_ok=True)\n'
f'manifest = json.loads(base64.b64decode({manifest_b64!r}))\n'
'written = []\n'
'for item in manifest:\n'
" p = os.path.join(target, item['name'])\n"
" with open(p, 'wb') as f:\n"
" f.write(base64.b64decode(item['b64']))\n"
' written.append(p)\n'
'print(json.dumps(written))\n'
)
result = await self.execute_tool(
{'command': f"python3 - <<'LBPY'\n{script}\nLBPY", 'timeout_sec': 120},
query,
)
if not result.get('ok'):
self.ap.logger.warning(
f'Failed to write inbound attachments into sandbox via exec: '
f'query_id={query.query_id} stderr={result.get("stderr", "")[:200]}'
)
return []
try:
return _json.loads(str(result.get('stdout') or '').strip().splitlines()[-1])
except Exception:
return []
async def materialize_inbound_attachments(self, query: pipeline_query.Query) -> list[dict]:
"""Persist message-chain attachments into the sandbox inbox.
Returns a list of ``{path, name, type, size}`` describing what was
written, so the runner can tell the LLM the exact in-sandbox paths.
Returns ``[]`` when sandbox is unavailable or there are no attachments.
"""
if not self._available:
return []
import langbot_plugin.api.entities.builtin.platform.message as platform_message
message_chain = getattr(query, 'message_chain', None)
if not message_chain:
return []
type_map = [
(platform_message.Image, 'Image', 'image', 'png'),
(platform_message.Voice, 'Voice', 'voice', 'wav'),
(platform_message.File, 'File', 'file', 'bin'),
]
pending: list[tuple[str, bytes]] = []
descriptors: list[dict] = []
index = 0
for component in message_chain:
matched = None
for cls, kind, prefix, default_ext in type_map:
if isinstance(component, cls):
matched = (kind, prefix, default_ext)
break
if matched is None:
continue
kind, prefix, default_ext = matched
payload = await self._component_to_bytes(component)
if payload is None:
continue
data, _mime = payload
if not data or len(data) > self._ATTACHMENT_MAX_BYTES:
continue
index += 1
raw_name = getattr(component, 'name', None) or f'{prefix}_{index}.{default_ext}'
safe_name = self._sanitize_attachment_name(raw_name, f'{prefix}_{index}.{default_ext}')
pending.append((safe_name, data))
descriptors.append(
{
'name': safe_name,
'type': kind,
'size': len(data),
}
)
if not pending:
return []
target_dir = f'{self.INBOX_MOUNT_DIR}/{query.query_id}'
written = await self._write_files_into_sandbox(query, self.INBOX_SUBDIR, target_dir, pending)
written_basenames = {os.path.basename(p) for p in written}
result: list[dict] = []
for desc in descriptors:
if desc['name'] in written_basenames:
desc['path'] = f'{target_dir}/{desc["name"]}'
result.append(desc)
if result:
self.ap.logger.info(
f'Materialized {len(result)} inbound attachment(s) into sandbox: '
f'query_id={query.query_id} dir={target_dir}'
)
return result
async def collect_outbound_attachments(self, query: pipeline_query.Query) -> list[dict]:
"""Collect files the agent produced in the sandbox outbox.
Reads ``/workspace/outbox/<query_id>/`` (recursively) — directly from
the bind-mounted host directory when available (no size limit), else
via the exec channel — returns a list of ``{type, name, base64}``
ready to become platform message components, then clears the outbox so
a later turn in the same session does not re-send stale files. Returns
``[]`` when nothing was produced.
"""
if not self._available:
return []
host_dir = self._host_query_dir(self.OUTBOX_SUBDIR, query.query_id)
if host_dir is not None:
entries = await asyncio.to_thread(self._read_outbox_host, host_dir)
else:
entries = await self._read_outbox_via_exec(query)
attachments = self._classify_outbound_entries(entries)
if attachments:
await self._clear_outbox(query, host_dir)
self.ap.logger.info(
f'Collected {len(attachments)} outbound attachment(s) from sandbox: query_id={query.query_id}'
)
return attachments
def _read_outbox_host(self, host_dir: str) -> list[dict]:
"""Read outbox files straight off the bind-mounted host directory."""
import base64 as _b64
entries: list[dict] = []
if not os.path.isdir(host_dir):
return entries
for root, _dirs, names in os.walk(host_dir):
for name in sorted(names):
path = os.path.join(root, name)
try:
if os.path.getsize(path) > self._ATTACHMENT_MAX_BYTES:
continue
with open(path, 'rb') as fh:
data = fh.read()
except OSError:
continue
rel = os.path.relpath(path, host_dir)
entries.append({'name': rel, 'b64': _b64.b64encode(data).decode('ascii')})
return entries
async def _read_outbox_via_exec(self, query: pipeline_query.Query) -> list[dict]:
"""Fallback: read the outbox over the exec channel (E2B / remote).
Note: exec stdout is truncated by ``output_limit_chars``, so this path
only reliably transfers small files. The host path is preferred.
"""
import json as _json
target_dir = f'{self.OUTBOX_MOUNT_DIR}/{query.query_id}'
max_bytes = self._EXEC_FALLBACK_MAX_BYTES
script = (
'import base64, json, os\n'
f'target = {target_dir!r}\n'
f'max_bytes = {max_bytes}\n'
'out = []\n'
'if os.path.isdir(target):\n'
' for root, _dirs, names in os.walk(target):\n'
' for n in sorted(names):\n'
' p = os.path.join(root, n)\n'
' try:\n'
' if os.path.getsize(p) > max_bytes:\n'
' continue\n'
" with open(p, 'rb') as f:\n"
' data = f.read()\n'
' except OSError:\n'
' continue\n'
' rel = os.path.relpath(p, target)\n'
" out.append({'name': rel, 'b64': base64.b64encode(data).decode('ascii')})\n"
'print(json.dumps(out))\n'
)
result = await self.execute_tool(
{'command': f"python3 - <<'LBPY'\n{script}\nLBPY", 'timeout_sec': 120},
query,
)
if not result.get('ok'):
return []
try:
return _json.loads(str(result.get('stdout') or '').strip().splitlines()[-1])
except Exception:
return []
async def _clear_outbox(self, query: pipeline_query.Query, host_dir: str | None) -> None:
"""Empty the per-query outbox after collection (host or exec)."""
if host_dir is not None:
import shutil
def _clear():
shutil.rmtree(host_dir, ignore_errors=True)
os.makedirs(host_dir, exist_ok=True)
await asyncio.to_thread(_clear)
return
target_dir = f'{self.OUTBOX_MOUNT_DIR}/{query.query_id}'
await self.execute_tool(
{'command': f'rm -rf {target_dir} && mkdir -p {target_dir}', 'timeout_sec': 30},
query,
)
@staticmethod
def _classify_outbound_entries(entries: list[dict]) -> list[dict]:
"""Classify outbox files into Image/Voice/File component descriptors."""
image_exts = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'bmp'}
voice_exts = {'wav', 'mp3', 'silk', 'amr', 'ogg', 'm4a', 'aac'}
mime_by_ext = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'gif': 'image/gif',
'webp': 'image/webp',
'bmp': 'image/bmp',
}
attachments: list[dict] = []
for entry in entries or []:
name = str(entry.get('name', '') or '')
b64 = entry.get('b64')
if not name or not b64:
continue
ext = name.rsplit('.', 1)[-1].lower() if '.' in name else ''
base_name = os.path.basename(name)
if ext in image_exts:
mime = mime_by_ext.get(ext, 'image/png')
attachments.append({'type': 'Image', 'name': base_name, 'base64': f'data:{mime};base64,{b64}'})
elif ext in voice_exts:
attachments.append({'type': 'Voice', 'name': base_name, 'base64': f'data:audio/{ext};base64,{b64}'})
else:
attachments.append({'type': 'File', 'name': base_name, 'base64': b64})
return attachments
async def shutdown(self):
await self.client.shutdown()
+54 -3
View File
@@ -7,6 +7,7 @@ from .. import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.events as events
@@ -23,6 +24,50 @@ class ResponseWrapper(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
pass
def _is_final_assistant_message(self, result) -> bool:
"""Whether *result* is the agent's final, tool-call-free answer.
Intermediate streaming chunks and tool-call rounds must NOT trigger
outbound attachment collection — only the terminal assistant message.
"""
if getattr(result, 'role', None) != 'assistant':
return False
if result.tool_calls:
return False
if isinstance(result, provider_message.MessageChunk):
return bool(result.is_final)
return True
async def _append_outbound_attachments(
self,
query: pipeline_query.Query,
message_chain: platform_message.MessageChain,
) -> None:
"""Collect sandbox outbox files and append them to *message_chain*.
Runs at most once per query (guarded by a query variable) and never
raises into the pipeline — attachment delivery is best-effort.
"""
if query.variables.get('_sandbox_outbound_collected'):
return
box_service = getattr(self.ap, 'box_service', None)
if box_service is None or not getattr(box_service, 'available', False):
return
query.variables['_sandbox_outbound_collected'] = True
try:
attachments = await box_service.collect_outbound_attachments(query)
except Exception as e:
self.ap.logger.warning(f'Outbound attachment collection failed: {e}')
return
for att in attachments:
att_type = att.get('type')
if att_type == 'Image':
message_chain.append(platform_message.Image(base64=att['base64']))
elif att_type == 'Voice':
message_chain.append(platform_message.Voice(base64=att['base64']))
else:
message_chain.append(platform_message.File(name=att.get('name', 'file'), base64=att['base64']))
async def process(
self,
query: pipeline_query.Query,
@@ -83,10 +128,16 @@ class ResponseWrapper(stage.PipelineStage):
)
else:
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
reply_chain = event_ctx.event.reply_message_chain
else:
query.resp_message_chain.append(result.get_content_platform_message_chain())
reply_chain = result.get_content_platform_message_chain()
# Attach files the agent produced in the sandbox
# outbox, but only on the terminal assistant message.
if self._is_final_assistant_message(result):
await self._append_outbound_attachments(query, reply_chain)
query.resp_message_chain.append(reply_chain)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -312,12 +312,18 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
async def _process_image_components(self, message_chain_obj: list):
"""
处理消息链中的图片和文件组件,将path转换为base64
处理消息链中的图片、语音和文件组件,将 path 转换为 base64
Image / Voice / File components uploaded from the web client carry a
storage key in ``path``. Resolve it to a base64 data URI so downstream
stages (multimodal LLM input and the Box sandbox inbox) have a usable
payload, then drop the now-consumed storage object.
Args:
message_chain_obj: 消息链对象列表
"""
import base64
import mimetypes
storage_mgr = self.ap.storage_mgr
@@ -325,31 +331,33 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
comp_type = component.get('type', '')
comp_path = component.get('path', '')
if not comp_path:
if not comp_path or comp_type not in ('Image', 'Voice', 'File'):
continue
if comp_type == 'Image':
try:
file_content = await storage_mgr.storage_provider.load(comp_path)
base64_str = base64.b64encode(file_content).decode('utf-8')
try:
file_content = await storage_mgr.storage_provider.load(comp_path)
base64_str = base64.b64encode(file_content).decode('utf-8')
file_key = comp_path
if file_key.lower().endswith(('.jpg', '.jpeg')):
lowered = comp_path.lower()
if comp_type == 'Image':
if lowered.endswith(('.jpg', '.jpeg')):
mime_type = 'image/jpeg'
elif file_key.lower().endswith('.png'):
mime_type = 'image/png'
elif file_key.lower().endswith('.gif'):
elif lowered.endswith('.gif'):
mime_type = 'image/gif'
elif file_key.lower().endswith('.webp'):
elif lowered.endswith('.webp'):
mime_type = 'image/webp'
else:
mime_type = 'image/png'
elif comp_type == 'Voice':
mime_type = mimetypes.guess_type(comp_path)[0] or 'audio/wav'
else: # File
mime_type = mimetypes.guess_type(comp_path)[0] or 'application/octet-stream'
component['base64'] = f'data:{mime_type};base64,{base64_str}'
await storage_mgr.storage_provider.delete(comp_path)
component['path'] = ''
except Exception as e:
await self.logger.error(f'Failed to load image file {comp_path}: {e}')
component['base64'] = f'data:{mime_type};base64,{base64_str}'
await storage_mgr.storage_provider.delete(comp_path)
component['path'] = ''
except Exception as e:
await self.logger.error(f'Failed to load {comp_type} file {comp_path}: {e}')
async def handle_websocket_message(
self,
@@ -216,11 +216,22 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
content = msg_dict.get('content')
if isinstance(content, list):
converted_parts = []
for part in content:
if isinstance(part, dict) and part.get('type') == 'image_base64':
part['image_url'] = {'url': part['image_base64']}
part['type'] = 'image_url'
del part['image_base64']
# OpenAI-compatible chat models reject non-image file parts
# (audio/document base64 or url). These originate from Voice /
# File attachments — including ones replayed from conversation
# history — and the agent already accesses their bytes via the
# sandbox. Drop them from the model payload to avoid
# "Invalid user message ... invalid content type=file_base64".
if isinstance(part, dict) and part.get('type') in ('file_base64', 'file_url'):
continue
converted_parts.append(part)
msg_dict['content'] = converted_parts
req_messages.append(msg_dict)
@@ -104,6 +104,68 @@ class _StreamAccumulator:
class LocalAgentRunner(runner.RequestRunner):
"""Local agent request runner"""
async def _inject_inbound_attachments(
self,
query: pipeline_query.Query,
user_message: provider_message.Message,
) -> None:
"""Persist inbound attachments into the sandbox and tell the model.
No-op when the box service is unavailable or there are no attachments.
On success, appends an extra text ContentElement to the user message
listing the in-sandbox paths and the outbox convention, and stashes the
descriptors in ``query.variables['_sandbox_inbound_attachments']``.
"""
box_service = getattr(self.ap, 'box_service', None)
if box_service is None or not getattr(box_service, 'available', False):
return
try:
attachments = await box_service.materialize_inbound_attachments(query)
except Exception as e: # never break the chat turn over attachment IO
self.ap.logger.warning(f'Inbound attachment materialization failed: {e}')
return
if not attachments:
return
query.variables['_sandbox_inbound_attachments'] = attachments
lines = [
'The user sent attachments. They have been saved into the sandbox and are '
'available to the exec/read/write tools at these paths:'
]
for att in attachments:
lines.append(f'- {att["type"]}: {att["path"]} ({att["size"]} bytes)')
outbox_dir = f'{box_service.OUTBOX_MOUNT_DIR}/{query.query_id}'
lines.append(
'If you produce any file (image, audio, document, etc.) that should be sent '
f'back to the user, write it into {outbox_dir}/ (create the directory if '
'needed). Every file placed there will be delivered to the user automatically.'
)
note = '\n'.join(lines)
# Voice/File attachments are now available to the agent via the sandbox
# (exec/read/write tools). Their raw bytes must NOT be forwarded to the
# chat model as multimodal content: providers reject non-image file
# parts ("Invalid user message ... ensure all user messages are valid
# OpenAI chat completion messages"). Strip those content elements and
# rely on the sandbox-path note instead. Images are kept so vision
# models can still see them.
_model_unsafe_types = {'file_base64', 'file_url'}
if isinstance(user_message.content, list):
user_message.content = [
ce for ce in user_message.content if getattr(ce, 'type', None) not in _model_unsafe_types
]
if isinstance(user_message.content, str):
user_message.content = [
provider_message.ContentElement.from_text(user_message.content),
provider_message.ContentElement.from_text(note),
]
elif isinstance(user_message.content, list):
user_message.content.append(provider_message.ContentElement.from_text(note))
else:
user_message.content = [provider_message.ContentElement.from_text(note)]
def _build_request_messages(
self,
query: pipeline_query.Query,
@@ -232,6 +294,12 @@ class LocalAgentRunner(runner.RequestRunner):
user_message = copy.deepcopy(query.user_message)
# Materialize inbound attachments (images / voices / files) into the
# sandbox so the agent's exec/read/write tools can operate on the real
# bytes — not just the multimodal copy the model sees. The exact
# in-sandbox paths are announced to the model as a system note.
await self._inject_inbound_attachments(query, user_message)
user_message_text = ''
if isinstance(user_message.content, str):
+59 -3
View File
@@ -1,6 +1,7 @@
# LangBot Test Suite
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
This directory contains the LangBot backend test suite, including unit tests,
integration tests, startup E2E tests, and container-backed Box runtime tests.
## Quality Gate Layers
@@ -10,10 +11,15 @@ LangBot uses a layered quality gate system for developers and CI:
|-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
**Note**: PostgreSQL migration tests and slow tests are NOT in local default
gates. They run in separate CI workflows. Frontend Playwright tests live under
`web/tests/e2e` and are documented in `web/README.md`.
### Developer Workflow
@@ -28,6 +34,9 @@ make test-all-local
bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min
uv run --python 3.12 pytest tests/e2e -q --tb=short
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
cd web && pnpm test:e2e
```
### Coverage Baseline
@@ -70,6 +79,12 @@ tests/
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── e2e/ # Real LangBot startup E2E tests
│ ├── conftest.py
│ ├── test_startup.py
│ └── utils/
├── integration_tests/ # Container-backed integration tests
│ └── box/ # Box runtime and MCP process tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
@@ -303,6 +318,44 @@ These tests:
- Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys
### Running backend E2E startup tests
Backend E2E tests start a real LangBot process with a generated minimal
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
They do not require provider keys or external services.
```bash
uv run --python 3.12 pytest tests/e2e -q --tb=short
```
These tests verify startup orchestration, migrations, API route registration,
and the minimal no-LLM startup path. The E2E process manager disables ambient
proxy variables for subprocess startup and uses direct localhost HTTP clients,
so local proxy settings should not affect the health checks.
### Running Box integration tests
Box integration tests exercise the real sandbox runtime path, including command
execution, session persistence, managed process WebSocket attachment, and
cleanup behavior.
```bash
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
```
These tests require a working Docker or Podman runtime. In CI, the dedicated
Box integration job checks Docker availability before running the tests.
### Running frontend E2E tests
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
and mock the LangBot backend and Space APIs, so no backend process is required.
```bash
cd web
pnpm test:e2e
```
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
@@ -320,6 +373,9 @@ Tests are automatically run on:
- Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
they exercise process/container behavior instead of pure Python compatibility.
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
## Adding New Tests
@@ -406,4 +462,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
- [ ] Add E2E tests
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis
- [ ] Add property-based testing with Hypothesis
+2 -2
View File
@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0) as client:
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client:
yield client
@pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db'
return e2e_tmpdir / 'data' / 'langbot.db'
+11 -6
View File
@@ -38,12 +38,13 @@ class TestStartupFlow:
# System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, e2e_db_path):
def test_database_initialized(self, langbot_process, e2e_db_path):
"""Verify SQLite database was created and initialized."""
assert e2e_db_path.exists()
# Database should have some tables after migration
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
@@ -74,10 +75,13 @@ class TestStartupFlow:
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint."""
# First startup may allow initial setup
response = e2e_client.post('/api/v1/user/auth', json={
'username': 'admin',
'password': 'admin',
})
response = e2e_client.post(
'/api/v1/user/auth',
json={
'user': 'admin',
'password': 'admin',
},
)
# Response could be:
# - 200 if auth succeeds
@@ -94,9 +98,10 @@ class TestStartupStages:
# If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, e2e_db_path):
def test_migrations_applied(self, langbot_process, e2e_db_path):
"""Verify database migrations were applied."""
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
+1 -1
View File
@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
for path in directories.values():
path.mkdir(parents=True, exist_ok=True)
return directories
return directories
+20 -3
View File
@@ -44,6 +44,17 @@ class LangBotProcess:
# Prepare environment
env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src')
for proxy_key in (
'HTTP_PROXY',
'HTTPS_PROXY',
'ALL_PROXY',
'http_proxy',
'https_proxy',
'all_proxy',
):
env.pop(proxy_key, None)
env['NO_PROXY'] = '127.0.0.1,localhost'
env['no_proxy'] = '127.0.0.1,localhost'
# Set API port via environment variable
env['API__PORT'] = str(self.port)
@@ -79,9 +90,11 @@ precision = 2
f.write(coveragerc_content)
cmd = [
'coverage', 'run',
'coverage',
'run',
'--rcfile=' + str(coveragerc_path),
'-m', 'langbot',
'-m',
'langbot',
]
else:
cmd = ['uv', 'run', 'python', '-m', 'langbot']
@@ -113,6 +126,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0,
follow_redirects=False,
trust_env=False,
)
if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}')
@@ -185,6 +200,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0,
follow_redirects=False,
trust_env=False,
)
return r.status_code == 200
except Exception:
@@ -201,4 +218,4 @@ def find_project_root() -> Path:
return parent
# Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build')
return Path('/home/glwuy/langbot-app/LangBot-test-build')
+36 -36
View File
@@ -58,45 +58,45 @@ from tests.factories.platform import (
__all__ = [
# App
"FakeApp",
"fake_app",
'FakeApp',
'fake_app',
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
'text_chain',
'group_text_chain',
'mention_chain',
'image_chain',
# Message events
"friend_message_event",
"group_message_event",
'friend_message_event',
'group_message_event',
# Mock adapters
"mock_adapter",
'mock_adapter',
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
"image_query",
"file_query",
"unsupported_query",
"voice_query",
"at_all_query",
"query_with_session",
"query_with_config",
'text_query',
'group_text_query',
'private_text_query',
'command_query',
'mention_query',
'empty_query',
'image_query',
'file_query',
'unsupported_query',
'voice_query',
'at_all_query',
'query_with_session',
'query_with_config',
# Provider
"FakeProvider",
"fake_provider",
"fake_provider_pong",
"fake_provider_timeout",
"fake_provider_auth_error",
"fake_provider_rate_limit",
"fake_provider_malformed",
"fake_model",
'FakeProvider',
'fake_provider',
'fake_provider_pong',
'fake_provider_timeout',
'fake_provider_auth_error',
'fake_provider_rate_limit',
'fake_provider_malformed',
'fake_model',
# Platform
"FakePlatform",
"fake_platform",
"fake_platform_with_streaming",
"fake_platform_with_failure",
"mock_platform_adapter",
]
'FakePlatform',
'fake_platform',
'fake_platform_with_streaming',
'fake_platform_with_failure',
'mock_platform_adapter',
]
+83 -77
View File
@@ -30,32 +30,36 @@ def _next_query_id() -> int:
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
def text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
return platform_message.MessageChain(
[
platform_message.At(target=target),
platform_message.Plain(text=f' {text}'),
]
)
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
@@ -66,13 +70,15 @@ def image_chain(
def command_chain(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=f'{prefix}{command}'),
]
)
# ============== Message Event Factories ==============
@@ -81,7 +87,7 @@ def command_chain(
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
@@ -90,7 +96,7 @@ def friend_message_event(
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -100,9 +106,9 @@ def friend_message_event(
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
@@ -117,7 +123,7 @@ def group_message_event(
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -152,36 +158,36 @@ def _base_query(
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'query_id': query_id,
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'message_chain': message_chain,
'message_event': message_event,
'adapter': adapter,
'pipeline_uuid': 'test-pipeline-uuid',
'bot_uuid': 'test-bot-uuid',
'pipeline_config': {
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
'session': None,
'prompt': None,
'messages': [],
'user_message': None,
'use_funcs': [],
'use_llm_model_uuid': None,
'variables': {},
'resp_messages': [],
'resp_message_chain': None,
'current_stage_name': None,
}
# Apply overrides
@@ -192,7 +198,7 @@ def _base_query(
def text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -212,7 +218,7 @@ def text_query(
def private_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -221,7 +227,7 @@ def private_text_query(
def group_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -242,8 +248,8 @@ def group_text_query(
def command_query(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -263,7 +269,7 @@ def command_query(
def mention_query(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
@@ -301,8 +307,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -322,9 +328,9 @@ def image_query(
def file_query(
url: str = "https://example.com/document.pdf",
name: str = "document.pdf",
text: str = "",
url: str = 'https://example.com/document.pdf',
name: str = 'document.pdf',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -348,8 +354,8 @@ def file_query(
def unsupported_query(
unsupported_type: str = "CustomComponent",
text: str = "",
unsupported_type: str = 'CustomComponent',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -358,7 +364,7 @@ def unsupported_query(
if text:
components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}'))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
@@ -374,7 +380,7 @@ def unsupported_query(
def query_with_session(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None,
**overrides,
@@ -389,7 +395,7 @@ def query_with_session(
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -398,7 +404,7 @@ def query_with_session(
def query_with_config(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None,
**overrides,
@@ -410,22 +416,22 @@ def query_with_config(
"""
if pipeline_config is None:
pipeline_config = {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
}
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query(
url: str = "https://example.com/audio.mp3",
url: str = 'https://example.com/audio.mp3',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -448,7 +454,7 @@ def voice_query(
def at_all_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -456,7 +462,7 @@ def at_all_query(
"""Create a group query with @All mention."""
components = [
platform_message.AtAll(),
platform_message.Plain(text=f" {text}"),
platform_message.Plain(text=f' {text}'),
]
chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id)
@@ -469,4 +475,4 @@ def at_all_query(
sender_id=sender_id,
adapter=adapter,
**overrides,
)
)
+52 -46
View File
@@ -33,7 +33,7 @@ class FakePlatform:
def __init__(
self,
*,
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
raise_error: Exception = None,
):
@@ -48,16 +48,16 @@ class FakePlatform:
# Registered listeners
self._listeners: dict = {}
def raises(self, error: Exception) -> "FakePlatform":
def raises(self, error: Exception) -> 'FakePlatform':
"""Configure platform to raise an error on send."""
self._raise_error = error
return self
def send_failure(self) -> "FakePlatform":
def send_failure(self) -> 'FakePlatform':
"""Configure platform to simulate send failure."""
return self.raises(Exception("Platform send failure"))
return self.raises(Exception('Platform send failure'))
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
def supports_streaming(self, supported: bool = True) -> 'FakePlatform':
"""Configure whether streaming output is supported."""
self._stream_output_supported = supported
return self
@@ -89,7 +89,7 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event."""
sender = platform_entities.Friend(
@@ -97,11 +97,13 @@ class FakePlatform:
nickname=nickname,
remark=None,
)
chain = platform_message.MessageChain([
platform_message.Plain(text=text),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -111,9 +113,9 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
mention_bot: bool = False,
) -> platform_events.GroupMessage:
"""Create an inbound group message event.
@@ -142,12 +144,12 @@ class FakePlatform:
components = []
if mention_bot:
components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=' '))
components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -155,8 +157,8 @@ class FakePlatform:
def create_image_message(
self,
url: str = "https://example.com/image.png",
text: str = "",
url: str = 'https://example.com/image.png',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
is_group: bool = False,
group_id: typing.Union[int, str] = 99999,
@@ -169,12 +171,12 @@ class FakePlatform:
chain = platform_message.MessageChain(components)
if is_group:
return self.create_group_message("", sender_id, group_id=group_id)
return self.create_group_message('', sender_id, group_id=group_id)
# Replace chain
else:
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -192,12 +194,14 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "send",
"target_type": target_type,
"target_id": target_id,
"message": message,
})
self._outbound_messages.append(
{
'type': 'send',
'target_type': target_type,
'target_id': target_id,
'message': message,
}
)
async def reply_message(
self,
@@ -209,13 +213,15 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "reply",
"source_type": message_source.type,
"source": message_source,
"message": message,
"quote_origin": quote_origin,
})
self._outbound_messages.append(
{
'type': 'reply',
'source_type': message_source.type,
'source': message_source,
'message': message,
'quote_origin': quote_origin,
}
)
async def reply_message_chunk(
self,
@@ -229,15 +235,17 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_chunks.append({
"type": "reply_chunk",
"source_type": message_source.type,
"source": message_source,
"bot_message": bot_message,
"message": message,
"quote_origin": quote_origin,
"is_final": is_final,
})
self._outbound_chunks.append(
{
'type': 'reply_chunk',
'source_type': message_source.type,
'source': message_source,
'bot_message': bot_message,
'message': message,
'quote_origin': quote_origin,
'is_final': is_final,
}
)
async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported."""
@@ -295,7 +303,7 @@ class FakePlatform:
def fake_platform(
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
) -> FakePlatform:
"""Create a FakePlatform instance."""
@@ -328,9 +336,7 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported)
adapter._fake_platform = platform # Store for assertions
return adapter
return adapter
+42 -38
View File
@@ -27,51 +27,51 @@ class FakeProvider:
Does not require API keys.
"""
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
PONG_RESPONSE = 'LANGBOT_FAKE_PONG'
def __init__(
self,
*,
default_response: str = "fake response",
default_response: str = 'fake response',
streaming_chunks: list[str] = None,
raise_error: Exception = None,
captured_requests: list = None,
):
self._default_response = default_response
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._streaming_chunks = streaming_chunks or ['fake ', 'response']
self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> "FakeProvider":
def returns(self, text: str) -> 'FakeProvider':
"""Configure provider to return a specific text response."""
self._default_response = text
self._streaming_chunks = [text]
return self
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
"""Configure provider to return streaming chunks."""
self._streaming_chunks = chunks
self._default_response = "".join(chunks)
self._default_response = ''.join(chunks)
return self
def raises(self, error: Exception) -> "FakeProvider":
def raises(self, error: Exception) -> 'FakeProvider':
"""Configure provider to raise an error."""
self._raise_error = error
return self
def timeout(self) -> "FakeProvider":
def timeout(self) -> 'FakeProvider':
"""Configure provider to simulate timeout."""
return self.raises(TimeoutError("Provider timeout"))
return self.raises(TimeoutError('Provider timeout'))
def auth_error(self) -> "FakeProvider":
def auth_error(self) -> 'FakeProvider':
"""Configure provider to simulate auth error."""
return self.raises(Exception("Invalid API key"))
return self.raises(Exception('Invalid API key'))
def rate_limit(self) -> "FakeProvider":
def rate_limit(self) -> 'FakeProvider':
"""Configure provider to simulate rate limit."""
return self.raises(Exception("Rate limit exceeded"))
return self.raises(Exception('Rate limit exceeded'))
def malformed(self) -> "FakeProvider":
def malformed(self) -> 'FakeProvider':
"""Configure provider to simulate malformed response."""
self._default_response = None
return self
@@ -87,7 +87,7 @@ class FakeProvider:
def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content."""
return provider_message.Message(
role="assistant",
role='assistant',
content=content,
)
@@ -99,7 +99,7 @@ class FakeProvider:
) -> provider_message.MessageChunk:
"""Create a provider message chunk."""
return provider_message.MessageChunk(
role="assistant",
role='assistant',
content=content,
is_final=is_final,
msg_sequence=msg_sequence,
@@ -116,13 +116,15 @@ class FakeProvider:
) -> provider_message.Message:
"""Simulate non-streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
}
)
# Simulate error if configured
if self._raise_error:
@@ -131,7 +133,7 @@ class FakeProvider:
# Return response
if self._default_response is None:
# Malformed response
return provider_message.Message(role="assistant", content=None)
return provider_message.Message(role='assistant', content=None)
return self._create_message(self._default_response)
@@ -146,14 +148,16 @@ class FakeProvider:
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
"streaming": True,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
'streaming': True,
}
)
# Simulate error if configured
if self._raise_error:
@@ -161,12 +165,12 @@ class FakeProvider:
# Yield chunks
for i, chunk in enumerate(self._streaming_chunks):
is_final = (i == len(self._streaming_chunks) - 1)
is_final = i == len(self._streaming_chunks) - 1
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider(
default_response: str = "fake response",
default_response: str = 'fake response',
) -> FakeProvider:
"""Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response)
@@ -202,8 +206,8 @@ def fake_provider_malformed() -> FakeProvider:
def fake_model(
*,
uuid: str = "test-model-uuid",
name: str = "test-model",
uuid: str = 'test-model-uuid',
name: str = 'test-model',
abilities: list[str] = None,
provider: FakeProvider = None,
) -> Mock:
@@ -212,7 +216,7 @@ def fake_model(
model.model_entity = Mock()
model.model_entity.uuid = uuid
model.model_entity.name = name
model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.abilities = abilities or ['func_call', 'vision']
model.model_entity.extra_args = {}
# Attach fake provider
@@ -221,4 +225,4 @@ def fake_model(
model.provider = provider
return model
return model
+1 -1
View File
@@ -3,4 +3,4 @@ Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""
"""
+1 -1
View File
@@ -2,4 +2,4 @@
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""
"""
+33 -34
View File
@@ -48,6 +48,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield
@@ -56,10 +57,12 @@ def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -71,28 +74,29 @@ def fake_bot_app():
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[
{
app.bot_service.get_bots = AsyncMock(
return_value=[
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
]
)
app.bot_service.get_runtime_bot_info = AsyncMock(
return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
}
])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
)
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1))
app.bot_service.send_message = AsyncMock()
# Platform manager
@@ -118,10 +122,7 @@ class TestBotEndpoints:
@pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'})
assert response.status_code == 200
data = await response.get_json()
@@ -135,7 +136,7 @@ class TestBotEndpoints:
response = await quart_test_client.post(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'},
)
assert response.status_code == 200
@@ -147,8 +148,7 @@ class TestBotEndpoints:
async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -162,7 +162,7 @@ class TestBotEndpoints:
response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}
json={'name': 'Updated Bot'},
)
assert response.status_code == 200
@@ -173,8 +173,7 @@ class TestBotEndpoints:
async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -190,7 +189,7 @@ class TestBotLogsEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}
json={'from_index': -1, 'max_count': 10},
)
assert response.status_code == 200
@@ -213,8 +212,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'person',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 200
@@ -228,7 +227,7 @@ class TestBotSendMessageEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]},
)
assert response.status_code == 400
@@ -244,8 +243,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'invalid',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 400
+22 -31
View File
@@ -47,6 +47,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield
@@ -55,10 +56,12 @@ def fake_embed_app():
"""Create FakeApp with embed widget services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock()
@@ -83,9 +86,7 @@ def fake_embed_app():
# WebSocket proxy bot with adapter
mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}])
mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock()
@@ -117,9 +118,7 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js')
assert response.status_code == 200
assert 'javascript' in response.content_type
@@ -127,18 +126,14 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js')
assert response.status_code == 400
@pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js')
assert response.status_code == 404
@@ -164,8 +159,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 200
@@ -177,8 +171,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 400
@@ -187,8 +180,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={}
)
assert response.status_code == 400
@@ -203,7 +195,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/person returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -216,7 +208,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/group returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -226,7 +218,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages with invalid session_type returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 400
@@ -241,7 +233,7 @@ class TestEmbedResetEndpoint:
"""POST reset/person resets session."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -252,8 +244,7 @@ class TestEmbedResetEndpoint:
async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@@ -269,7 +260,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}
json={'message_id': 'msg-123', 'feedback_type': 1},
)
assert response.status_code == 200
@@ -283,7 +274,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}
json={'message_id': 'msg-123', 'feedback_type': 2},
)
assert response.status_code == 200
@@ -294,7 +285,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}
json={'message_id': 'msg-123', 'feedback_type': 99},
)
assert response.status_code == 400
+36 -38
View File
@@ -49,6 +49,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,33 +75,35 @@ def fake_knowledge_app():
# Knowledge service
app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
{
app.knowledge_service.get_knowledge_bases = AsyncMock(
return_value=[
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
]
)
app.knowledge_service.get_knowledge_base = AsyncMock(
return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
)
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
])
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}]
)
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}])
# RAG manager
app.rag_mgr = Mock()
@@ -124,8 +129,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -140,7 +144,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'},
)
assert response.status_code == 200
@@ -152,8 +156,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -167,7 +170,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}
json={'name': 'Updated KB'},
)
assert response.status_code == 200
@@ -178,8 +181,7 @@ class TestKnowledgeBaseEndpoints:
async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -193,8 +195,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -208,7 +209,7 @@ class TestKnowledgeBaseFilesEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'},
)
assert response.status_code == 200
@@ -220,8 +221,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}},
)
assert response.status_code == 200
@@ -249,9 +249,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={}
)
assert response.status_code == 400
+49 -67
View File
@@ -46,6 +46,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield
@@ -54,10 +55,12 @@ def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock()
@@ -67,40 +70,34 @@ def fake_monitoring_app():
# Monitoring service
app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
})
app.monitoring_service.get_messages = AsyncMock(return_value=(
[{'id': 'msg-1', 'content': 'test'}], 100
))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
[{'id': 'llm-1'}], 50
))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
[{'id': 'emb-1'}], 10
))
app.monitoring_service.get_sessions = AsyncMock(return_value=(
[{'session_id': 'sess-1'}], 20
))
app.monitoring_service.get_errors = AsyncMock(return_value=(
[{'id': 'err-1'}], 2
))
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'found': True,
'session_id': 'sess-1',
})
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_overview_metrics = AsyncMock(
return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
}
)
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10))
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20))
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2))
app.monitoring_service.get_session_analysis = AsyncMock(
return_value={
'found': True,
'session_id': 'sess-1',
}
)
app.monitoring_service.get_message_details = AsyncMock(
return_value={
'found': True,
'message_id': 'msg-1',
}
)
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
@@ -130,8 +127,7 @@ class TestMonitoringOverviewEndpoint:
async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get(
'/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -147,8 +143,7 @@ class TestMonitoringMessagesEndpoint:
async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -165,8 +160,7 @@ class TestMonitoringLLMCallsEndpoint:
async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -180,8 +174,7 @@ class TestMonitoringEmbeddingCallsEndpoint:
async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -195,8 +188,7 @@ class TestMonitoringSessionsEndpoint:
async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -210,8 +202,7 @@ class TestMonitoringErrorsEndpoint:
async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors."""
response = await quart_test_client.get(
'/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -225,8 +216,7 @@ class TestMonitoringAllDataEndpoint:
async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get(
'/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -242,8 +232,7 @@ class TestMonitoringDetailsEndpoints:
async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -252,8 +241,7 @@ class TestMonitoringDetailsEndpoints:
async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -267,8 +255,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -277,8 +264,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -292,8 +278,7 @@ class TestMonitoringExportEndpoint:
async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -303,8 +288,7 @@ class TestMonitoringExportEndpoint:
async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -313,8 +297,7 @@ class TestMonitoringExportEndpoint:
async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -323,8 +306,7 @@ class TestMonitoringExportEndpoint:
async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
+32 -42
View File
@@ -49,6 +49,7 @@ def mock_circular_import_chain():
):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,27 +75,23 @@ def fake_provider_app():
# Provider service
app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock(return_value=[
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
])
app.provider_service.get_provider = AsyncMock(return_value={
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
})
app.provider_service.get_providers = AsyncMock(
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}]
)
app.provider_service.get_provider = AsyncMock(
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
)
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
})
app.provider_service.get_provider_model_counts = AsyncMock(
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0}
)
# LLM model service
app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}])
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock()
@@ -133,8 +132,7 @@ class TestProviderEndpoints:
async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -157,8 +155,7 @@ class TestProviderEndpoints:
async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -177,7 +174,7 @@ class TestProviderEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}
json={'name': 'New Provider', 'requester': 'chatcmpl'},
)
assert response.status_code == 200
@@ -194,7 +191,7 @@ class TestProviderEndpoints:
response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}
json={'name': 'Updated Provider'},
)
assert response.status_code == 200
@@ -205,8 +202,7 @@ class TestProviderEndpoints:
async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -215,8 +211,7 @@ class TestProviderEndpoints:
async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,8 +232,7 @@ class TestModelEndpoints:
async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -250,8 +244,7 @@ class TestModelEndpoints:
async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -264,7 +257,7 @@ class TestModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -276,8 +269,7 @@ class TestModelEndpoints:
async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -291,8 +283,7 @@ class TestEmbeddingModelEndpoints:
async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -306,7 +297,7 @@ class TestEmbeddingModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -323,8 +314,7 @@ class TestRerankModelEndpoints:
async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -338,7 +328,7 @@ class TestRerankModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
+14 -12
View File
@@ -20,6 +20,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -69,6 +70,7 @@ def mock_circular_import_chain():
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
@@ -79,12 +81,14 @@ def fake_api_app():
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# API-specific services
app.user_service = Mock()
@@ -118,6 +122,7 @@ def fake_api_app():
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls):
"""
@@ -135,6 +140,7 @@ async def quart_test_client(fake_api_app, http_controller_cls):
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@@ -222,8 +228,7 @@ class TestProtectedEndpoints:
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@@ -254,10 +259,7 @@ class TestInvalidPayload:
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'})
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)
+1 -1
View File
@@ -2,4 +2,4 @@
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""
"""
@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
db_file = tmp_path / 'test_migrations.db'
return f'sqlite+aiosqlite:///{db_file}'
@pytest.fixture
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestSQLiteMigrationFreshDatabase:
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_migrations_fresh.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
assert rev is not None, 'Expected a revision on fresh DB'
await fresh_engine.dispose()
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_empty_migrations.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
# Verify specific behavior - one of two outcomes is expected
if actual_result is not None:
# Migration succeeded - verify revision exists
assert actual_result is not None, "Revision should exist after successful migration"
assert actual_result is not None, 'Revision should exist after successful migration'
else:
# Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables
assert actual_error is not None, "Error should be captured if migration failed"
assert actual_error is not None, 'Error should be captured if migration failed'
# Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios
acceptable_errors = [
'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors
'CommandError', # Alembic command errors
]
assert error_type in acceptable_errors, (
f"Unexpected error type: {error_type}. "
f"This may indicate a regression in migration behavior. "
f"Error: {actual_error}"
f'Unexpected error type: {error_type}. '
f'This may indicate a regression in migration behavior. '
f'Error: {actual_error}'
)
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'
@@ -34,14 +34,14 @@ def postgres_url():
"""Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL')
if not url:
pytest.skip("TEST_POSTGRES_URL not set")
pytest.skip('TEST_POSTGRES_URL not set')
return url
@pytest.fixture
async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT')
yield engine
await engine.dispose()
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -83,9 +83,7 @@ class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp baseline on existing tables sets correct revision.
@@ -106,9 +104,7 @@ class TestPostgreSQLMigrationBaseline:
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp on empty database (no tables) still sets revision.
@@ -125,9 +121,7 @@ class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Upgrade from baseline to head applies all migrations.
@@ -149,14 +143,12 @@ class TestPostgreSQLMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev is not None, "Expected a revision after upgrade"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration (0005 for current state)
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev.startswith('0005'), f'Expected head to be 0005_*, got {rev}'
@pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Running upgrade to head multiple times is idempotent.
@@ -180,7 +172,7 @@ class TestPostgreSQLMigrationUpgrade:
await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestPostgreSQLMigrationGetCurrent:
@@ -199,7 +191,7 @@ class TestPostgreSQLMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(postgres_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision(
@@ -214,4 +206,4 @@ class TestPostgreSQLMigrationGetCurrent:
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'
+1 -1
View File
@@ -2,4 +2,4 @@
Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner.
"""
"""
+41 -21
View File
@@ -26,6 +26,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -103,6 +104,7 @@ def mock_circular_import_chain():
# ============== FAKE RUNNER ==============
class FakeRunner:
"""Minimal fake runner class for pipeline integration tests.
@@ -117,12 +119,13 @@ class FakeRunner:
self.config = config or {}
self._provider = FakeProvider()
# Instance-level configuration set via class attribute
self._response_text = "fake response"
self._response_text = 'fake response'
self._raise_error = None
@classmethod
def returns(cls, text: str):
"""Create a runner class configured to return specific text."""
# We create a subclass with configured response
class ConfiguredRunner(cls):
name = cls.name
@@ -132,11 +135,13 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._response_text = text
return ConfiguredRunner
@classmethod
def raises(cls, error: Exception):
"""Create a runner class configured to raise an error."""
class ConfiguredRunner(cls):
name = cls.name
_response_text = None
@@ -145,6 +150,7 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._raise_error = error
return ConfiguredRunner
async def run(self, query):
@@ -161,6 +167,7 @@ class FakeRunner:
# ============== PIPELINE APP FIXTURE ==============
@pytest.fixture
def pipeline_app():
"""
@@ -187,6 +194,7 @@ def pipeline_app():
def __init__(self, name, messages):
self.name = name
self.messages = messages
def copy(self):
return MockPrompt(self.name, list(self.messages))
@@ -237,14 +245,17 @@ def fake_platform_adapter():
@pytest.fixture
def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner
# ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests."""
return {
@@ -273,6 +284,7 @@ def create_minimal_pipeline_config():
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name):
"""
Helper to handle the coroutine -> async_generator pattern.
@@ -296,6 +308,7 @@ async def collect_processor_results(processor, query, stage_name):
# ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain."""
@@ -337,7 +350,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
# Create query with adapter
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -365,7 +378,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
query = text_query("test message content")
query = text_query('test message content')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -396,11 +409,11 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Set fake runner that returns pong
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
@@ -414,6 +427,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -432,7 +446,7 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -445,6 +459,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -462,13 +477,13 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Create reply chain
reply_chain = text_chain("plugin response")
reply_chain = text_chain('plugin response')
# Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock()
@@ -479,6 +494,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -502,7 +518,7 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded'))
set_fake_runner(fake_runner)
# Create query with exception handling config
@@ -510,7 +526,7 @@ class TestRunnerExceptionFlow:
config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -523,6 +539,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -541,14 +558,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error'))
set_fake_runner(fake_runner)
# Create query with show-error mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -561,6 +578,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -578,14 +596,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception("Hidden error"))
fake_runner = FakeRunner().raises(Exception('Hidden error'))
set_fake_runner(fake_runner)
# Create query with hide mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -598,6 +616,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -623,7 +642,7 @@ class TestSendResponseBackStage:
adapter, platform = fake_platform_adapter
# Create query with response message
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -666,12 +685,12 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Set fake runner
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
config = create_minimal_pipeline_config()
query = text_query("ping")
query = text_query('ping')
query.adapter = adapter
query.pipeline_config = config
query.resp_messages = []
@@ -690,7 +709,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -711,6 +730,7 @@ class TestStageChainIntegration:
# Build resp_message_chain from resp_messages
from tests.factories.message import text_chain
for resp_msg in query.resp_messages:
if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content))
@@ -737,7 +757,7 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -754,7 +774,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -775,4 +795,4 @@ class TestStageChainIntegration:
assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages
assert len(query.resp_messages) == 0
assert len(query.resp_messages) == 0
+1 -1
View File
@@ -3,4 +3,4 @@ Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q
"""
"""
+67 -64
View File
@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
assert app.instance_config is not None
# Verify default config
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data["command"]["enable"] is True
assert app.instance_config.data['command']['prefix'] == ['/', '!']
assert app.instance_config.data['command']['enable'] is True
@pytest.mark.asyncio
async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response."""
provider = FakeProvider(default_response="test response")
provider = FakeProvider(default_response='test response')
# Create mock model with provider
model = fake_model(provider=provider)
# Create a simple query
query = text_query("hello")
query = text_query('hello')
# Simulate invoke
result = await provider.invoke_llm(
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
)
assert result is not None
assert result.role == "assistant"
assert result.content == "test response"
assert result.role == 'assistant'
assert result.content == 'test response'
@pytest.mark.asyncio
async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong()
model = fake_model(provider=provider)
query = text_query("ping")
query = text_query('ping')
result = await provider.invoke_llm(
query=query,
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
@pytest.mark.asyncio
async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(["Hello", " World"])
provider = FakeProvider().returns_streaming(['Hello', ' World'])
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
chunks = []
# invoke_llm_stream returns an async generator, don't await it
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].content == "Hello"
assert chunks[1].content == " World"
assert chunks[0].content == 'Hello'
assert chunks[1].content == ' World'
assert chunks[1].is_final is True
@pytest.mark.asyncio
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(TimeoutError, match="Provider timeout"):
with pytest.raises(TimeoutError, match='Provider timeout'):
await provider.invoke_llm(
query=query,
model=model,
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(Exception, match="Rate limit exceeded"):
with pytest.raises(Exception, match='Rate limit exceeded'):
await provider.invoke_llm(
query=query,
model=model,
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments."""
provider = FakeProvider()
model = fake_model(name="gpt-4", provider=provider)
query = text_query("hello")
model = fake_model(name='gpt-4', provider=provider)
query = text_query('hello')
await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "hello"}],
funcs=[{"name": "test_func"}],
extra_args={"temperature": 0.7},
messages=[{'role': 'user', 'content': 'hello'}],
funcs=[{'name': 'test_func'}],
extra_args={'temperature': 0.7},
)
captured = provider.get_captured_requests()
assert len(captured) == 1
assert captured[0]["model"] == "gpt-4"
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]["extra_args"] == {"temperature": 0.7}
assert captured[0]['model'] == 'gpt-4'
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}]
assert captured[0]['funcs'] == [{'name': 'test_func'}]
assert captured[0]['extra_args'] == {'temperature': 0.7}
@pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id="test-bot")
query = text_query("hello")
platform = FakePlatform(bot_account_id='test-bot')
query = text_query('hello')
# Simulate sending reply
from tests.factories.message import text_chain
reply_chain = text_chain("response text")
reply_chain = text_chain('response text')
event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False)
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
# Verify captured
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert outbound[0]["message"] == reply_chain
assert outbound[0]['type'] == 'reply'
assert outbound[0]['message'] == reply_chain
@pytest.mark.asyncio
async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_friend_message(
text="hello bot",
text='hello bot',
sender_id=12345,
nickname="TestUser",
nickname='TestUser',
)
assert event.type == "FriendMessage"
assert event.type == 'FriendMessage'
assert event.sender.id == 12345
assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == "hello bot"
assert event.sender.nickname == 'TestUser'
assert str(event.message_chain) == 'hello bot'
@pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_group_message(
text="hello everyone",
text='hello everyone',
sender_id=12345,
group_id=99999,
mention_bot=True,
)
assert event.type == "GroupMessage"
assert event.type == 'GroupMessage'
assert event.sender.id == 12345
assert event.group.id == 99999
@@ -220,54 +220,57 @@ class TestFakeMessageFlow:
async def test_query_factories_basic(self):
"""Test basic query factory functions."""
# Text query
q1 = text_query("hello world")
assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == "hello world"
q1 = text_query('hello world')
assert q1.launcher_type.value == 'person'
assert str(q1.message_chain) == 'hello world'
# Group query
from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
assert q2.launcher_type.value == "group"
q2 = group_text_query('hello group', group_id=88888)
assert q2.launcher_type.value == 'group'
assert q2.launcher_id == 88888
# Command query
from tests.factories import command_query
q3 = command_query("help", prefix="/")
assert str(q3.message_chain) == "/help"
q3 = command_query('help', prefix='/')
assert str(q3.message_chain) == '/help'
# Mention query
from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
assert q4.launcher_type.value == "group"
q4 = mention_query('hi', target='test-bot', group_id=77777)
assert q4.launcher_type.value == 'group'
@pytest.mark.asyncio
async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure()
query = text_query("hello")
query = text_query('hello')
from tests.factories.message import text_chain
with pytest.raises(Exception, match="Platform send failure"):
with pytest.raises(Exception, match='Platform send failure'):
await platform.reply_message(
query.message_event,
text_chain("response"),
text_chain('response'),
)
@pytest.mark.asyncio
async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id="bot-123")
platform = FakePlatform(bot_account_id='bot-123')
adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == "bot-123"
assert adapter.bot_account_id == 'bot-123'
assert adapter._fake_platform is platform
# Test reply_message is wired
from tests.factories.message import text_chain
query = text_query("test")
await adapter.reply_message(query.message_event, text_chain("response"))
query = text_query('test')
await adapter.reply_message(query.message_event, text_chain('response'))
# Verify platform captured it
assert len(platform.get_outbound_messages()) == 1
@@ -293,18 +296,18 @@ class TestMessageFlowIntegration:
Note: This does NOT run actual LangBot pipeline stages.
"""
# Setup
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
provider = fake_provider_pong()
model = fake_model(provider=provider)
# Create inbound message
query = text_query("ping")
query = text_query('ping')
# Simulate provider processing
response = await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "ping"}],
messages=[{'role': 'user', 'content': 'ping'}],
funcs=[],
extra_args={},
)
@@ -321,16 +324,16 @@ class TestMessageFlowIntegration:
# Verify platform captured outbound
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
assert outbound[0]['type'] == 'reply'
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(["Hello", " there"])
provider = FakeProvider().returns_streaming(['Hello', ' there'])
model = fake_model(provider=provider)
query = text_query("hi")
query = text_query('hi')
chunks = []
async for chunk in provider.invoke_llm_stream(
@@ -344,8 +347,8 @@ class TestMessageFlowIntegration:
# Verify streaming worked
assert len(chunks) == 2
full_content = "".join(c.content for c in chunks)
assert full_content == "Hello there"
full_content = ''.join(c.content for c in chunks)
assert full_content == 'Hello there'
# Verify platform supports streaming
assert await platform.is_stream_output_supported() is True
assert await platform.is_stream_output_supported() is True
+10 -21
View File
@@ -15,22 +15,12 @@ import pathlib
# Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py'
def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
tree = ast.parse(source)
@@ -40,27 +30,26 @@ def test_no_exec_call_in_system_controller():
if isinstance(node, ast.Call):
func = node.func
# Match bare exec() call
if isinstance(func, ast.Name) and func.id == "exec":
if isinstance(func, ast.Name) and func.id == 'exec':
exec_calls.append(node.lineno)
assert len(exec_calls) == 0, (
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().'
)
def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
assert "debug/exec" not in source, (
"The /debug/exec route still exists in system.py. "
"This endpoint allows arbitrary code execution and must be removed."
assert 'debug/exec' not in source, (
'The /debug/exec route still exists in system.py. '
'This endpoint allows arbitrary code execution and must be removed.'
)
if __name__ == "__main__":
if __name__ == '__main__':
test_no_exec_call_in_system_controller()
test_no_debug_exec_route()
print("All tests passed!")
print('All tests passed!')
+1 -1
View File
@@ -1 +1 @@
"""Unit tests for LangBot API HTTP service layer."""
"""Unit tests for LangBot API HTTP service layer."""
+1 -1
View File
@@ -13,4 +13,4 @@ Does NOT:
- Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application.
"""
"""
@@ -132,9 +132,7 @@ class TestApiKeyServiceCreateApiKey:
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}]
assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key'
@@ -303,13 +303,7 @@ class TestBotServiceCreateBot:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
@@ -318,9 +312,7 @@ class TestBotServiceCreateBot:
bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'})
service = BotService(ap)
@@ -352,6 +344,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -362,9 +355,7 @@ class TestBotServiceCreateBot:
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'})
service = BotService(ap)
@@ -397,6 +388,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -492,6 +484,7 @@ class TestBotServiceUpdateBot:
pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -582,10 +575,9 @@ class TestBotServiceListEventLogs:
# Mock runtime bot with logger
runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock(return_value=(
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
5
))
runtime_bot.logger.get_logs = AsyncMock(
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5)
)
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
@@ -646,11 +638,7 @@ class TestBotServiceSendMessage:
service = BotService(ap)
# Execute with valid message chain format
message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]}
# Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
@@ -6,6 +6,7 @@ Tests cover:
- Knowledge engine discovery
- File operations
"""
from __future__ import annotations
import pytest
@@ -52,9 +53,7 @@ class TestGetKnowledgeBases:
"""Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}])
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
@@ -83,9 +82,7 @@ class TestGetKnowledgeBase:
"""Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'})
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1')
@@ -153,9 +150,7 @@ class TestCreateKnowledgeBase:
service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'})
# Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
@@ -170,20 +165,21 @@ class TestUpdateKnowledgeBase:
"""Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'})
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields
await service.update_knowledge_base('kb1', {
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
})
await service.update_knowledge_base(
'kb1',
{
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
},
)
# Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args
@@ -288,9 +284,7 @@ class TestListKnowledgeEngines:
"""Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
@@ -386,12 +380,10 @@ class TestGetEngineSchemas:
"""Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert result == {}
mock_app.logger.warning.assert_called_once()
mock_app.logger.warning.assert_called_once()
@@ -174,9 +174,7 @@ class TestMaintenanceServiceGetStorageAnalysis:
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
@@ -292,12 +290,8 @@ class TestMaintenanceServiceGetStorageAnalysis:
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[
{'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}])
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}])
# Execute
result = await service.get_storage_analysis()
@@ -367,6 +361,7 @@ class TestMaintenanceServiceBinaryStorageStats:
size_result = _create_mock_result(scalar_value=5000)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +391,7 @@ class TestMaintenanceServiceBinaryStorageStats:
count_result = _create_mock_result(scalar_value=5)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -821,4 +817,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included
assert 'path' in result[0]
assert 'path' in result[0]
@@ -186,13 +186,7 @@ class TestMCPServiceCreateMCPServer:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}}
ap.plugin_connector = SimpleNamespace()
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
@@ -252,6 +246,7 @@ class TestMCPServiceCreateMCPServer:
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -361,6 +356,7 @@ class TestMCPServiceUpdateMCPServer:
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -394,6 +390,7 @@ class TestMCPServiceUpdateMCPServer:
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -432,6 +429,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock for: first select -> update -> second select (for updated server)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -465,6 +463,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock execute for select and update
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -499,6 +498,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Server to Delete')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -530,6 +530,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Not in Sessions')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -559,6 +560,7 @@ class TestMCPServiceDeleteMCPServer:
# No server found
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -596,9 +598,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=123)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123))
service = MCPService(ap)
@@ -634,9 +634,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=456)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456))
service = MCPService(ap)
@@ -645,4 +643,4 @@ class TestMCPServiceTestMCPServer:
# Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456
assert task_id == 456
@@ -167,6 +167,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([])
call_count = 0
async def mock_execute(query):
return mock_result if call_count == 0 else mock_provider_result
@@ -200,6 +201,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -239,6 +241,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -279,6 +282,7 @@ class TestLLMModelsServiceGetLLMModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -337,9 +341,7 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'model-1', 'name': 'Model 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'})
service = LLMModelsService(ap)
@@ -371,12 +373,14 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
})
model_uuid = await service.create_llm_model(
{
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -400,13 +404,16 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}, preserve_uuid=True)
model_uuid = await service.create_llm_model(
{
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
},
preserve_uuid=True,
)
# Verify
assert model_uuid == 'preserved-uuid'
@@ -459,12 +466,14 @@ class TestLLMModelsServiceCreateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_llm_model({
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
})
await service.create_llm_model(
{
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
}
)
async def test_create_llm_model_with_provider_data(self):
"""Creates provider when provider data provided."""
@@ -490,16 +499,18 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute - with provider data (no UUID)
result_uuid = await service.create_llm_model({
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
})
result_uuid = await service.create_llm_model(
{
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
}
)
# Verify - provider_service was called and UUID generated
ap.provider_service.find_or_create_provider.assert_called_once()
@@ -525,11 +536,14 @@ class TestLLMModelsServiceUpdateLLMModel:
service = LLMModelsService(ap)
# Execute
await service.update_llm_model('existing-uuid', {
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
})
await service.update_llm_model(
'existing-uuid',
{
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
},
)
# Verify - remove and load called
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
@@ -549,10 +563,13 @@ class TestLLMModelsServiceUpdateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.update_llm_model('model-uuid', {
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
})
await service.update_llm_model(
'model-uuid',
{
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
},
)
async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args."""
@@ -618,9 +635,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'})
service = EmbeddingModelsService(ap)
@@ -643,6 +658,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -683,6 +699,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -742,11 +759,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
service = EmbeddingModelsService(ap)
# Execute
model_uuid = await service.create_embedding_model({
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_embedding_model(
{
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -767,11 +786,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_embedding_model({
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_embedding_model(
{
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
@@ -829,6 +850,7 @@ class TestRerankModelsServiceGetRerankModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -869,6 +891,7 @@ class TestRerankModelsServiceGetRerankModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -928,11 +951,13 @@ class TestRerankModelsServiceCreateRerankModel:
service = RerankModelsService(ap)
# Execute
model_uuid = await service.create_rerank_model({
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_rerank_model(
{
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -952,11 +977,13 @@ class TestRerankModelsServiceCreateRerankModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_rerank_model({
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_rerank_model(
{
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestRerankModelsServiceDeleteRerankModel:
@@ -995,9 +1022,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'})
service = EmbeddingModelsService(ap)
@@ -1022,9 +1047,7 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'})
service = RerankModelsService(ap)
@@ -1042,14 +1065,10 @@ class TestValidateProviderSupports:
def _make_ap(requester_name: str, support_type):
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
manifest = SimpleNamespace(spec={'support_type': support_type})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester=requester_name)
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest
if name == requester_name
else None,
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None,
)
return SimpleNamespace(model_mgr=model_mgr)
@@ -1066,9 +1085,7 @@ class TestValidateProviderSupports:
async def test_allows_when_support_type_missing(self):
# Manifest without support_type must not block (backward compatible)
manifest = SimpleNamespace(spec={})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester='legacy')
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy'))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest,
@@ -215,13 +215,7 @@ class TestPipelineServiceCreatePipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -229,9 +223,7 @@ class TestPipelineServiceCreatePipeline:
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'})
service = PipelineService(ap)
@@ -258,14 +250,14 @@ class TestPipelineServiceCreatePipeline:
# Mock persistence for insert
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
# Mock the file read for default config - patch at the utils module level
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
# Verify
@@ -286,7 +278,9 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
service.get_pipeline = AsyncMock(
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
)
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
@@ -296,7 +290,9 @@ class TestPipelineServiceCreatePipeline:
# Mock the file read
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
# Verify - execute was called
@@ -316,10 +312,12 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
})
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
}
)
insert_params = []
@@ -339,7 +337,9 @@ class TestPipelineServiceCreatePipeline:
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'New Pipeline'})
assert len(insert_params) == 1
@@ -353,6 +353,7 @@ class TestPipelineServiceCreatePipeline:
class _MockResultWithBots:
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
def __init__(self, bots_list):
self._bots_list = bots_list
@@ -428,6 +429,7 @@ class TestPipelineServiceUpdatePipeline:
# 1. UPDATE (line 125) - returns Mock (no result needed)
# 2. SELECT bots (line 136) - returns bot_result with .all()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -528,13 +530,7 @@ class TestPipelineServiceCopyPipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -542,10 +538,12 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
# Mock get_pipelines to return 2 pipelines
service.get_pipelines = AsyncMock(return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
])
service.get_pipelines = AsyncMock(
return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
]
)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'):
@@ -642,9 +640,7 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'copy-uuid', 'is_default': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False})
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
@@ -681,11 +677,10 @@ class TestPipelineServiceUpdatePipelineExtensions:
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
)
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []})
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -700,7 +695,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -711,7 +706,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -738,6 +733,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
original_pipeline = _create_mock_pipeline()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -752,7 +748,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_mcp_servers': False,
'mcp_servers': ['mcp-server-1'],
}
},
}
)
@@ -794,6 +790,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -245,12 +245,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
provider_uuid = await service.create_provider({
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
provider_uuid = await service.create_provider(
{
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - UUID is generated
assert provider_uuid is not None
@@ -274,12 +276,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
result_uuid = await service.create_provider({
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
result_uuid = await service.create_provider(
{
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - provider added to runtime dict and UUID generated
ap.model_mgr.load_provider.assert_called_once()
@@ -302,10 +306,13 @@ class TestModelProviderServiceUpdateProvider:
service = ModelProviderService(ap)
# Execute
await service.update_provider('existing-uuid', {
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
})
await service.update_provider(
'existing-uuid',
{
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
},
)
# Verify - reload called
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
@@ -364,6 +371,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=None)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +404,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -454,6 +463,7 @@ class TestModelProviderServiceGetProviderModelCounts:
rerank_result.scalar = Mock(return_value=1)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -637,9 +647,7 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
await service.update_space_model_provider_api_keys('space-api-key')
# Verify - update and reload called for Space provider UUID
ap.model_mgr.reload_provider.assert_called_once_with(
'00000000-0000-0000-0000-000000000000'
)
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000')
class TestModelProviderServiceScanProviderModels:
@@ -795,9 +803,7 @@ class TestModelProviderServiceScanProviderModels:
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
runtime_provider.requester.scan_models = AsyncMock(
side_effect=NotImplementedError('scan not supported')
)
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported'))
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
service = ModelProviderService(ap)
@@ -848,9 +854,7 @@ class TestModelProviderServiceScanProviderModels:
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing LLM model
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
return_value=[{'name': 'Existing Model'}]
)
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}])
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
@@ -863,4 +867,4 @@ class TestModelProviderServiceScanProviderModels:
assert existing_model['already_added'] is True
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
assert new_model['already_added'] is False
assert new_model['already_added'] is False
@@ -393,14 +393,16 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -429,10 +431,12 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 1,
'msg': 'Invalid refresh token',
})
mock_response.json = AsyncMock(
return_value={
'code': 1,
'msg': 'Invalid refresh token',
}
)
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
@@ -489,14 +493,16 @@ class TestSpaceServiceExchangeOAuthCode:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -555,13 +561,15 @@ class TestSpaceServiceGetUserInfoRaw:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -669,27 +677,29 @@ class TestSpaceServiceGetModels:
# Mock HTTP response with proper model data matching SpaceModel schema
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -775,4 +785,4 @@ class TestSpaceServiceCreditsCache:
# Verify - cache updated
assert result == 500
assert 'test@example.com' in service._credits_cache
assert service._credits_cache['test@example.com'][0] == 500
assert service._credits_cache['test@example.com'][0] == 500
@@ -495,6 +495,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -565,6 +566,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -605,4 +607,4 @@ class TestUserServiceCreateUserLock:
# Verify lock exists
assert hasattr(service, '_create_user_lock')
assert service._create_user_lock is not None
assert service._create_user_lock is not None
@@ -132,6 +132,7 @@ class TestWebhookServiceCreateWebhook:
# execute_async returns different results
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -181,6 +182,7 @@ class TestWebhookServiceCreateWebhook:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -217,6 +219,7 @@ class TestWebhookServiceCreateWebhook:
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -225,9 +228,7 @@ class TestWebhookServiceCreateWebhook:
return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'id': 1, 'enabled': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False})
service = WebhookService(ap)
@@ -503,4 +504,4 @@ class TestWebhookServiceGetEnabledWebhooks:
result = await service.get_enabled_webhooks()
# Verify - should be empty (SQL would filter disabled)
assert result == []
assert result == []
+256 -4
View File
@@ -407,7 +407,9 @@ def test_box_service_forced_template_ignores_pipeline_config():
launcher_type='person',
launcher_id='test_user',
sender_id='test_user',
pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
pipeline_config={
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
},
)
assert service.resolve_box_session_id(query) == 'global'
@@ -1527,9 +1529,7 @@ class TestBuildSkillExtraMounts:
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
]
# No skill is dropped, so no "missing" warning should be logged.
assert not any(
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
)
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list)
def test_skips_skill_with_empty_package_root(self):
logger = Mock()
@@ -1556,3 +1556,255 @@ class TestBuildSkillExtraMounts:
service = BoxService(app, client=Mock(spec=BoxRuntimeClient))
assert service.build_skill_extra_mounts(make_query()) == []
# ── Attachment passthrough (inbound / outbound) ─────────────────────────────
class TestAttachmentHelpers:
def test_sanitize_attachment_name_strips_traversal(self):
assert BoxService._sanitize_attachment_name('../../etc/passwd', 'fb') == 'passwd'
assert BoxService._sanitize_attachment_name('/a/b/c.png', 'fb') == 'c.png'
assert BoxService._sanitize_attachment_name('a b c.txt', 'fb') == 'a_b_c.txt'
assert BoxService._sanitize_attachment_name('', 'fallback.bin') == 'fallback.bin'
assert BoxService._sanitize_attachment_name('...', 'fb.bin') == 'fb.bin'
# weird unicode / shell chars dropped, but keeps a usable name
out = BoxService._sanitize_attachment_name('rm -rf $(x).png', 'fb')
assert '/' not in out and '$' not in out and out.endswith('.png')
def test_classify_outbound_entries_by_extension(self):
entries = [
{'name': 'chart.png', 'b64': 'AAA'},
{'name': 'clip.mp3', 'b64': 'BBB'},
{'name': 'report.pdf', 'b64': 'CCC'},
{'name': 'sub/dir/photo.JPG', 'b64': 'DDD'},
{'name': 'noext', 'b64': 'EEE'},
{'name': 'skip', 'b64': ''}, # dropped (no payload)
]
out = BoxService._classify_outbound_entries(entries)
by_name = {a['name']: a for a in out}
assert by_name['chart.png']['type'] == 'Image'
assert by_name['chart.png']['base64'].startswith('data:image/png;base64,')
assert by_name['clip.mp3']['type'] == 'Voice'
assert by_name['clip.mp3']['base64'].startswith('data:audio/mp3;base64,')
assert by_name['report.pdf']['type'] == 'File'
assert by_name['report.pdf']['base64'] == 'CCC' # raw b64, no data: prefix
# nested path collapses to basename, case-insensitive ext
assert by_name['photo.JPG']['type'] == 'Image'
assert by_name['noext']['type'] == 'File'
assert 'skip' not in by_name
@pytest.mark.asyncio
async def test_component_to_bytes_from_data_uri(self):
import base64
raw = b'hello-bytes'
data_uri = 'data:text/plain;base64,' + base64.b64encode(raw).decode()
component = SimpleNamespace(base64=data_uri, url=None, path=None)
result = await BoxService._component_to_bytes(component)
assert result is not None
data, mime = result
assert data == raw
assert mime == 'text/plain'
@pytest.mark.asyncio
async def test_component_to_bytes_returns_none_when_empty(self):
component = SimpleNamespace(base64=None, url=None, path=None)
assert await BoxService._component_to_bytes(component) is None
class TestInboundOutboundRoundTrip:
def _service(self) -> BoxService:
service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient))
service._available = True
return service
@pytest.mark.asyncio
async def test_materialize_inbound_writes_and_describes(self):
import base64
import langbot_plugin.api.entities.builtin.platform.message as platform_message
service = self._service()
img_bytes = b'\x89PNG\r\n\x1a\n fake png'
img_b64 = 'data:image/png;base64,' + base64.b64encode(img_bytes).decode()
query = make_query()
query.message_chain = platform_message.MessageChain(
[
platform_message.Plain(text='please resize this'),
platform_message.Image(base64=img_b64),
]
)
# Mock the sandbox write path: echo back the written paths.
async def fake_execute_tool(parameters, q):
assert '/workspace/inbox/' in parameters['command']
return {
'ok': True,
'stdout': '["/workspace/inbox/42/image_1.png"]',
'stderr': '',
}
service.execute_tool = AsyncMock(side_effect=fake_execute_tool)
descriptors = await service.materialize_inbound_attachments(query)
assert len(descriptors) == 1
d = descriptors[0]
assert d['type'] == 'Image'
assert d['path'] == '/workspace/inbox/42/image_1.png'
assert d['size'] == len(img_bytes)
@pytest.mark.asyncio
async def test_materialize_inbound_noop_without_attachments(self):
import langbot_plugin.api.entities.builtin.platform.message as platform_message
service = self._service()
query = make_query()
query.message_chain = platform_message.MessageChain([platform_message.Plain(text='just text')])
service.execute_tool = AsyncMock()
assert await service.materialize_inbound_attachments(query) == []
service.execute_tool.assert_not_called()
@pytest.mark.asyncio
async def test_collect_outbound_reads_and_clears(self):
service = self._service()
query = make_query()
calls = []
async def fake_execute_tool(parameters, q):
calls.append(parameters['command'])
if 'os.walk' in parameters['command']:
return {
'ok': True,
'stdout': '[{"name": "out.png", "b64": "QUJD"}]',
'stderr': '',
}
# the rm -rf cleanup call
return {'ok': True, 'stdout': '', 'stderr': ''}
service.execute_tool = AsyncMock(side_effect=fake_execute_tool)
attachments = await service.collect_outbound_attachments(query)
assert len(attachments) == 1
assert attachments[0]['type'] == 'Image'
assert attachments[0]['name'] == 'out.png'
# cleanup (rm -rf) must have been issued after a successful collection
assert any('rm -rf' in c for c in calls)
@pytest.mark.asyncio
async def test_collect_outbound_empty_no_cleanup(self):
service = self._service()
query = make_query()
calls = []
async def fake_execute_tool(parameters, q):
calls.append(parameters['command'])
return {'ok': True, 'stdout': '[]', 'stderr': ''}
service.execute_tool = AsyncMock(side_effect=fake_execute_tool)
assert await service.collect_outbound_attachments(query) == []
assert not any('rm -rf' in c for c in calls)
@pytest.mark.asyncio
async def test_passthrough_noop_when_unavailable(self):
service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient))
service._available = False
query = make_query()
assert await service.materialize_inbound_attachments(query) == []
assert await service.collect_outbound_attachments(query) == []
class TestAttachmentHostPath:
"""Direct host-filesystem transfer path (bind-mounted workspace).
When ``default_workspace`` is a real local dir, inbound/outbound bypass the
exec channel entirely (no ARG_MAX / stdout-truncation limits) and read/write
the bind-mounted host dir directly.
"""
def _service_with_workspace(self, tmp_path):
ws = str(tmp_path / 'box' / 'default')
os.makedirs(ws, exist_ok=True)
app = make_app(Mock(), allowed_mount_roots=[str(tmp_path)], host_root=str(tmp_path / 'box'))
service = BoxService(app, client=Mock(spec=BoxRuntimeClient))
service._available = True
# Force the default_workspace to our tmp dir so _host_query_dir resolves.
service.default_workspace = ws
return service, ws
@pytest.mark.asyncio
async def test_inbound_writes_to_host_no_exec(self, tmp_path):
import base64
import langbot_plugin.api.entities.builtin.platform.message as platform_message
service, ws = self._service_with_workspace(tmp_path)
# Big payload that would blow ARG_MAX on the exec path:
big = b'\x89PNG\r\n\x1a\n' + b'x' * (300 * 1024)
b64 = 'data:image/png;base64,' + base64.b64encode(big).decode()
query = make_query()
query.message_chain = platform_message.MessageChain([platform_message.Image(base64=b64)])
# execute_tool must NOT be called on the host path.
service.execute_tool = AsyncMock(side_effect=AssertionError('exec must not be used on host path'))
descriptors = await service.materialize_inbound_attachments(query)
assert len(descriptors) == 1
d = descriptors[0]
assert d['type'] == 'Image'
assert d['size'] == len(big)
# File actually landed on the host workspace.
host_file = os.path.join(ws, 'inbox', str(query.query_id), d['name'])
assert os.path.isfile(host_file)
assert open(host_file, 'rb').read() == big
@pytest.mark.asyncio
async def test_inbound_host_clears_stale_query_dir(self, tmp_path):
import base64
import langbot_plugin.api.entities.builtin.platform.message as platform_message
service, ws = self._service_with_workspace(tmp_path)
# Seed a stale file under the same query_id (simulates webchat id reuse).
stale_dir = os.path.join(ws, 'inbox', '42')
os.makedirs(stale_dir, exist_ok=True)
open(os.path.join(stale_dir, 'image_1.png'), 'wb').write(b'STALE-OLD-IMAGE')
new = b'\x89PNG\r\n\x1a\n NEW'
b64 = 'data:image/png;base64,' + base64.b64encode(new).decode()
query = make_query(query_id=42)
query.message_chain = platform_message.MessageChain([platform_message.Image(base64=b64)])
service.execute_tool = AsyncMock()
descriptors = await service.materialize_inbound_attachments(query)
# The new write recreated the dir; the stale file is gone, new bytes present.
host_file = os.path.join(stale_dir, descriptors[0]['name'])
assert open(host_file, 'rb').read() == new
# No leftover content from the stale image.
assert b'STALE-OLD-IMAGE' not in open(host_file, 'rb').read()
@pytest.mark.asyncio
async def test_outbound_reads_host_and_clears(self, tmp_path):
service, ws = self._service_with_workspace(tmp_path)
query = make_query()
outbox = os.path.join(ws, 'outbox', str(query.query_id))
os.makedirs(outbox, exist_ok=True)
# A large file that would be truncated on the exec/stdout path:
big_png = b'\x89PNG\r\n\x1a\n' + b'y' * (400 * 1024)
open(os.path.join(outbox, 'result.png'), 'wb').write(big_png)
open(os.path.join(outbox, 'notes.txt'), 'wb').write(b'hello')
service.execute_tool = AsyncMock(side_effect=AssertionError('exec must not be used on host path'))
attachments = await service.collect_outbound_attachments(query)
by_name = {a['name']: a for a in attachments}
assert by_name['result.png']['type'] == 'Image'
assert by_name['notes.txt']['type'] == 'File'
# Full image survived (no truncation).
import base64
raw = base64.b64decode(by_name['result.png']['base64'].split(',', 1)[-1])
assert raw == big_png
# Outbox cleared after collection.
assert os.listdir(outbox) == []
+1 -1
View File
@@ -1 +1 @@
# Unit tests for command module
# Unit tests for command module
+1 -1
View File
@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
# Should yield CommandNotFoundError (no such command registered)
assert len(results) == 1
assert results[0].error is not None
assert results[0].error is not None
+2 -1
View File
@@ -197,6 +197,7 @@ class TestCommandOperatorBase:
op = TestOperator(None)
# Should not raise
import asyncio
asyncio.get_event_loop().run_until_complete(op.initialize())
def test_execute_is_abstract(self):
@@ -299,4 +300,4 @@ class TestMultipleOperators:
yield None
assert AdminOperator.lowest_privilege == 2
assert SubOperator.lowest_privilege == 1
assert SubOperator.lowest_privilege == 1
+29 -25
View File
@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
@pytest.mark.asyncio
async def test_valid_yaml_loads(self, tmp_path):
"""Valid YAML config should load correctly."""
config_file = tmp_path / "test_config.yaml"
config_file = tmp_path / 'test_config.yaml'
# Write valid YAML
config_file.write_text("""
@@ -51,7 +51,7 @@ settings:
@pytest.mark.asyncio
async def test_invalid_yaml_raises_error(self, tmp_path):
"""Invalid YAML should raise clear error."""
config_file = tmp_path / "invalid.yaml"
config_file = tmp_path / 'invalid.yaml'
# Write invalid YAML (unclosed bracket)
config_file.write_text("""
@@ -67,13 +67,13 @@ settings:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await yaml_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_config_creates_from_template(self, tmp_path):
"""Missing config file should be created from template."""
config_file = tmp_path / "new_config.yaml"
config_file = tmp_path / 'new_config.yaml'
# File doesn't exist yet
assert not config_file.exists()
@@ -92,7 +92,7 @@ settings:
@pytest.mark.asyncio
async def test_template_completion(self, tmp_path):
"""Config should be completed with template defaults."""
config_file = tmp_path / "partial.yaml"
config_file = tmp_path / 'partial.yaml'
# Write partial config missing some template keys
config_file.write_text("""
@@ -115,7 +115,7 @@ name: custom_name
@pytest.mark.asyncio
async def test_yaml_save(self, tmp_path):
"""YAML config can be saved."""
config_file = tmp_path / "save_test.yaml"
config_file = tmp_path / 'save_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -131,7 +131,7 @@ name: custom_name
def test_yaml_save_sync(self, tmp_path):
"""YAML config can be saved synchronously."""
config_file = tmp_path / "sync_save.yaml"
config_file = tmp_path / 'sync_save.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -151,14 +151,18 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_valid_json_loads(self, tmp_path):
"""Valid JSON config should load correctly."""
config_file = tmp_path / "test_config.json"
config_file = tmp_path / 'test_config.json'
# Write valid JSON
config_file.write_text(json.dumps({
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}))
config_file.write_text(
json.dumps(
{
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}
)
)
json_file = JSONConfigFile(
str(config_file),
@@ -174,7 +178,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_invalid_json_raises_error(self, tmp_path):
"""Invalid JSON should raise clear error."""
config_file = tmp_path / "invalid.json"
config_file = tmp_path / 'invalid.json'
# Write invalid JSON (missing closing brace)
config_file.write_text('{"name": "test", "unclosed": ')
@@ -184,13 +188,13 @@ class TestJSONConfigFile:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await json_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_json_creates_from_template(self, tmp_path):
"""Missing JSON file should be created from template."""
config_file = tmp_path / "new_config.json"
config_file = tmp_path / 'new_config.json'
json_file = JSONConfigFile(
str(config_file),
@@ -205,7 +209,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_json_save(self, tmp_path):
"""JSON config can be saved."""
config_file = tmp_path / "save_test.json"
config_file = tmp_path / 'save_test.json'
json_file = JSONConfigFile(
str(config_file),
@@ -226,7 +230,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_load(self, tmp_path):
"""ConfigManager loads config correctly."""
config_file = tmp_path / "manager_test.yaml"
config_file = tmp_path / 'manager_test.yaml'
config_file.write_text('name: managed_app\nversion: "1.0"\n')
yaml_file = YAMLConfigFile(
@@ -243,7 +247,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_dump(self, tmp_path):
"""ConfigManager can dump config."""
config_file = tmp_path / "dump_test.yaml"
config_file = tmp_path / 'dump_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -260,7 +264,7 @@ class TestConfigManager:
def test_config_manager_dump_sync(self, tmp_path):
"""ConfigManager can dump config synchronously."""
config_file = tmp_path / "sync_dump.yaml"
config_file = tmp_path / 'sync_dump.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -280,7 +284,7 @@ class TestConfigExists:
def test_yaml_exists_true(self, tmp_path):
"""exists() returns True for existing file."""
config_file = tmp_path / "exists.yaml"
config_file = tmp_path / 'exists.yaml'
config_file.write_text('name: test')
yaml_file = YAMLConfigFile(str(config_file), template_data={})
@@ -288,14 +292,14 @@ class TestConfigExists:
def test_yaml_exists_false(self, tmp_path):
"""exists() returns False for missing file."""
config_file = tmp_path / "missing.yaml"
config_file = tmp_path / 'missing.yaml'
yaml_file = YAMLConfigFile(str(config_file), template_data={})
assert yaml_file.exists() is False
def test_json_exists_true(self, tmp_path):
"""exists() returns True for existing JSON file."""
config_file = tmp_path / "exists.json"
config_file = tmp_path / 'exists.json'
config_file.write_text('{}')
json_file = JSONConfigFile(str(config_file), template_data={})
@@ -303,7 +307,7 @@ class TestConfigExists:
def test_json_exists_false(self, tmp_path):
"""exists() returns False for missing JSON file."""
config_file = tmp_path / "missing.json"
config_file = tmp_path / 'missing.json'
json_file = JSONConfigFile(str(config_file), template_data={})
assert json_file.exists() is False
assert json_file.exists() is False
+1 -1
View File
@@ -1 +1 @@
"""Core module unit tests."""
"""Core module unit tests."""
@@ -4,6 +4,7 @@ Tests cover:
- _get_positive_int_config() validation
- _get_positive_float_config() validation
"""
from __future__ import annotations
from unittest.mock import Mock
@@ -188,4 +189,4 @@ class TestGetPositiveFloatConfig:
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
assert result == 1.5
mock_logger.warning.assert_called_once()
mock_logger.warning.assert_called_once()
@@ -27,6 +27,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert result == []
@@ -46,6 +47,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert 'requests' in result
@@ -61,6 +63,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
# Should include all required_deps keys
@@ -107,6 +110,7 @@ class TestPrecheckPluginDeps:
with patch('os.path.exists', return_value=False):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_not_called()
@@ -129,6 +133,7 @@ class TestPrecheckPluginDeps:
with patch('os.listdir', side_effect=mock_listdir):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])
+4 -10
View File
@@ -7,6 +7,7 @@ Tests cover:
- Dict type skipping
- Missing key creation
"""
from __future__ import annotations
import os
@@ -248,15 +249,8 @@ class TestApplyEnvOverridesToConfig:
"""Test multiple env vars applied in order."""
load_config = get_load_config_module()
cfg = {
'system': {'name': 'default', 'enable': True},
'concurrency': {'pipeline': 5}
}
env = {
'SYSTEM__NAME': 'custom',
'SYSTEM__ENABLE': 'false',
'CONCURRENCY__PIPELINE': '10'
}
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}}
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'}
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
@@ -287,4 +281,4 @@ class TestApplyEnvOverridesToConfig:
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
+1 -1
View File
@@ -175,4 +175,4 @@ class TestPreregisteredStages:
pass
for key in preregistered_stages:
assert isinstance(key, str)
assert isinstance(key, str)
+23 -23
View File
@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import_isolation to break circular import chains.
"""
from __future__ import annotations
import pytest
@@ -19,15 +20,17 @@ from typing import Generator
class MockLifecycleControlScopeEnum:
"""Mock enum value for LifecycleControlScope with .value attribute."""
def __init__(self, value: str):
self.value = value
def __repr__(self):
return f"LifecycleControlScope.{self.value.upper()}"
return f'LifecycleControlScope.{self.value.upper()}'
class MockLifecycleControlScope:
"""Mock enum for LifecycleControlScope."""
APPLICATION = MockLifecycleControlScopeEnum('application')
PLATFORM = MockLifecycleControlScopeEnum('platform')
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
@@ -40,17 +43,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
# Mock modules that cause circular imports
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
mock_app = MagicMock()
mock_importutil = MagicMock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
mock_http_controller = MagicMock()
mock_rag_mgr = MagicMock()
mocks = {
'langbot.pkg.core.entities': mock_entities,
'langbot.pkg.core.app': mock_app,
@@ -58,26 +61,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
'langbot.pkg.utils.importutil': mock_importutil,
}
# Save original state
saved = {}
for name in mocks:
if name in sys.modules:
saved[name] = sys.modules[name]
# Clear taskmgr to force re-import
taskmgr_name = 'langbot.pkg.core.taskmgr'
if taskmgr_name in sys.modules:
saved[taskmgr_name] = sys.modules[taskmgr_name]
try:
# Apply mocks
for name, module in mocks.items():
sys.modules[name] = module
# Clear taskmgr
sys.modules.pop(taskmgr_name, None)
yield
finally:
# Restore
@@ -86,7 +89,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
sys.modules[name] = saved[name]
else:
sys.modules.pop(name, None)
if taskmgr_name in saved:
sys.modules[taskmgr_name] = saved[taskmgr_name]
else:
@@ -97,6 +100,7 @@ def get_taskmgr_classes():
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
return TaskContext, TaskWrapper, AsyncTaskManager
@@ -194,9 +198,10 @@ class TestTaskContext:
"""Test TaskContext.placeholder() returns singleton."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext
# Reset global placeholder
import langbot.pkg.core.taskmgr as taskmgr_module
taskmgr_module.placeholder_context = None
ctx1 = TaskContext.placeholder()
@@ -269,7 +274,8 @@ class TestTaskWrapper:
return 'result'
wrapper = TaskWrapper(
mock_app, immediate_coro(),
mock_app,
immediate_coro(),
name='test_task',
label='Test Task',
)
@@ -414,7 +420,7 @@ class TestAsyncTaskManager:
async def test_cancel_by_scope(self):
"""Test cancel_by_scope cancels matching tasks."""
_, _, AsyncTaskManager = get_taskmgr_classes()
mock_app = create_mock_app()
manager = AsyncTaskManager(mock_app)
@@ -422,16 +428,10 @@ class TestAsyncTaskManager:
await asyncio.sleep(10)
# Create task with APPLICATION scope
w1 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.APPLICATION]
)
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION])
# Create task with different scope
w2 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.PIPELINE]
)
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE])
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)
+41 -41
View File
@@ -15,68 +15,68 @@ class TestI18nString:
def test_create_with_english_only(self):
"""Create I18nString with only English."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
assert i18n.en_US == "Hello"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans is None
def test_create_with_multiple_languages(self):
"""Create I18nString with multiple languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
)
assert i18n.en_US == "Hello"
assert i18n.zh_Hans == "你好"
assert i18n.zh_Hant == "你好"
assert i18n.ja_JP == "こんにちは"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans == '你好'
assert i18n.zh_Hant == '你好'
assert i18n.ja_JP == 'こんにちは'
def test_to_dict_with_english_only(self):
"""to_dict returns only non-None fields."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
result = i18n.to_dict()
assert result == {"en_US": "Hello"}
assert result == {'en_US': 'Hello'}
def test_to_dict_with_multiple_languages(self):
"""to_dict returns all non-None fields."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
en_US='Hello',
zh_Hans='你好',
)
result = i18n.to_dict()
assert result == {"en_US": "Hello", "zh_Hans": "你好"}
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'}
def test_to_dict_excludes_none(self):
"""to_dict excludes None values."""
i18n = I18nString(
en_US="Hello",
en_US='Hello',
zh_Hans=None,
ja_JP="こんにちは",
ja_JP='こんにちは',
)
result = i18n.to_dict()
assert "zh_Hans" not in result
assert "en_US" in result
assert "ja_JP" in result
assert 'zh_Hans' not in result
assert 'en_US' in result
assert 'ja_JP' in result
def test_to_dict_all_languages(self):
"""to_dict with all supported languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
th_TH="สวัสดี",
vi_VN="Xin chào",
es_ES="Hola",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
th_TH='สวัสดี',
vi_VN='Xin chào',
es_ES='Hola',
)
result = i18n.to_dict()
@@ -92,30 +92,30 @@ class TestMetadata:
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test Component"),
name='test-component',
label=I18nString(en_US='Test Component'),
)
assert metadata.name == "test-component"
assert metadata.label.en_US == "Test Component"
assert metadata.name == 'test-component'
assert metadata.label.en_US == 'Test Component'
def test_create_with_all_fields(self):
"""Create Metadata with all optional fields."""
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test"),
description=I18nString(en_US="A test component"),
version="1.0.0",
icon="test-icon",
author="Test Author",
repository="https://github.com/test/repo",
name='test-component',
label=I18nString(en_US='Test'),
description=I18nString(en_US='A test component'),
version='1.0.0',
icon='test-icon',
author='Test Author',
repository='https://github.com/test/repo',
)
assert metadata.version == "1.0.0"
assert metadata.icon == "test-icon"
assert metadata.author == "Test Author"
assert metadata.version == '1.0.0'
assert metadata.icon == 'test-icon'
assert metadata.author == 'Test Author'
class TestComponentManifest:
@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import isolation to break circular import chains.
"""
from __future__ import annotations
import sys
@@ -86,6 +87,7 @@ def get_database_module():
"""Get database module with import isolation."""
with isolated_database_import():
from langbot.pkg.persistence import database
return database
@@ -198,4 +200,4 @@ class TestManagerClassDecorator:
# Create instance to test method (with mock app)
mock_app = Mock()
instance = ManagerWithMethods(mock_app)
assert instance.custom_method() == 'test_value'
assert instance.custom_method() == 'test_value'
@@ -4,6 +4,7 @@ Tests cover:
- execute_async() with mock database
- get_db_engine() with mock database manager
"""
from __future__ import annotations
import pytest
@@ -85,7 +86,7 @@ class TestExecuteAsync:
mock_db.get_engine = Mock(return_value=mock_engine)
mgr.db = mock_db
result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
result = await mgr.execute_async(sqlalchemy.text('SELECT 1'))
# Verify result is the same object returned by execute
assert result is mock_result
@@ -152,4 +153,4 @@ class TestSerializeModelEdgeCases:
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
# Result should be empty dict when all columns masked
assert result == {}
assert result == {}
@@ -5,6 +5,7 @@ Tests cover:
- datetime conversion to isoformat
- masked_columns exclusion
"""
from __future__ import annotations
import datetime
+14 -14
View File
@@ -49,7 +49,7 @@ class TestPendingMessage:
"""PendingMessage should be created with correct fields."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -88,7 +88,7 @@ class TestSessionBuffer:
"""SessionBuffer should accept initial messages."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("hello")
chain2 = text_chain("world")
chain1 = text_chain('hello')
chain2 = text_chain('world')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
# Should contain both messages with separator
merged_str = str(merged.message_chain)
assert "hello" in merged_str
assert "world" in merged_str
assert 'hello' in merged_str
assert 'world' in merged_str
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("first")
chain2 = text_chain("second")
chain1 = text_chain('first')
chain2 = text_chain('second')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
+34 -4
View File
@@ -15,6 +15,7 @@ from tests.factories import FakeApp
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -36,9 +37,11 @@ def mock_circular_import_chain():
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
@@ -70,9 +73,12 @@ def mock_event_ctx():
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
@@ -87,6 +93,7 @@ def get_chat_handler():
global _chat_handler_module
if _chat_handler_module is None:
from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module
@@ -96,12 +103,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class."""
@@ -188,9 +197,11 @@ class TestChatMessageHandlerReal:
class QuickRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='ok')
@@ -222,9 +233,11 @@ class TestChatMessageHandlerReal:
class SingleRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='response')
@@ -262,9 +275,11 @@ class TestChatHandlerStreaming:
class StreamRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True)
@@ -303,14 +318,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class FailingRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('API error')
yield
@@ -346,14 +366,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class ErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('Custom error')
yield
@@ -386,14 +411,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class HideErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise RuntimeError('hidden')
yield
@@ -433,4 +463,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result
+22 -24
View File
@@ -67,7 +67,11 @@ def make_pipeline_config(**overrides):
for key, value in overrides.items():
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
for sub_key, sub_value in value.items():
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
if (
sub_key in base_config[key]
and isinstance(base_config[key][sub_key], dict)
and isinstance(sub_value, dict)
):
base_config[key][sub_key].update(sub_value)
else:
base_config[key][sub_key] = sub_value
@@ -141,7 +145,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -163,7 +167,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
# Empty message chain
query = text_query("")
query = text_query('')
query.message_chain = platform_message.MessageChain([])
query.pipeline_config = pipeline_config
@@ -185,7 +189,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query(" ") # Only whitespace
query = text_query(' ') # Only whitespace
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -234,7 +238,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -266,7 +270,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -294,7 +298,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("http://example.com")
query = text_query('http://example.com')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -322,7 +326,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("normal message")
query = text_query('normal message')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -343,7 +347,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -368,12 +372,10 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Add a response message
query.resp_messages = [
provider_message.Message(role='assistant', content='Hello back!')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -398,11 +400,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='Response')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -422,7 +422,7 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-string content - use model_construct to bypass validation
# The actual content type could be a list of ContentElement objects
@@ -450,11 +450,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='')
]
query.resp_messages = [provider_message.Message(role='assistant', content='')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
@@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect:
await stage.initialize(pipeline_config)
query = text_query("normal message without prefix")
query = text_query('normal message without prefix')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -56,6 +57,7 @@ def mock_event_ctx():
@pytest.fixture
def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute(
text: str | None = 'ok',
error: str | None = None,
@@ -71,7 +73,9 @@ def mock_execute_factory():
ret.image_base64 = image_base64
ret.file_url = file_url
yield ret
return mock_execute
return _create_execute
@@ -86,6 +90,7 @@ def get_command_handler():
global _command_handler_module
if _command_handler_module is None:
from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module
@@ -95,12 +100,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal:
"""Tests for real CommandHandler class."""
@@ -127,6 +134,7 @@ class TestCommandHandlerReal:
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
executed_commands = []
async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text)
ret = Mock()
@@ -334,8 +342,7 @@ class TestCommandHandlerReal:
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:',
image_url='https://example.com/image.png'
text='Here is the image:', image_url='https://example.com/image.png'
)
handler = command.CommandHandler(fake_app)
@@ -393,4 +400,4 @@ class TestCommandHandlerHelper:
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result
+19 -27
View File
@@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="very long response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="short response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-Plain component (Image)
query.resp_message_chain = [
platform_message.MessageChain([
platform_message.Plain(text="short"),
platform_message.Image(url="https://example.com/img.png")
])
platform_message.MessageChain(
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
)
]
result = await stage.process(query, 'LongTextProcessStage')
@@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess:
stage = longtext.LongTextProcessStage(app)
stage.strategy_impl = AsyncMock()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
query.resp_message_chain = []
@@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess:
assert result.new_query is query
stage.strategy_impl.process.assert_not_called()
class TestForwardStrategy:
"""Tests for ForwardComponentStrategy."""
@@ -260,7 +256,7 @@ class TestForwardStrategy:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Create a mock adapter with bot_account_id
mock_adapter = Mock()
@@ -268,10 +264,8 @@ class TestForwardStrategy:
query.adapter = mock_adapter
# Long text exceeding threshold
long_text = "This is a very long response that exceeds the threshold"
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=long_text)])
]
long_text = 'This is a very long response that exceeds the threshold'
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -297,13 +291,13 @@ class TestForwardStrategy:
await strat.initialize()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config()
mock_adapter = Mock()
mock_adapter.bot_account_id = '12345'
query.adapter = mock_adapter
components = await strat.process("test message", query)
components = await strat.process('test message', query)
assert len(components) == 1
assert isinstance(components[0], platform_message.Forward)
@@ -326,14 +320,12 @@ class TestLongTextThreshold:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Text below threshold
short_text = "x" * (threshold - 1)
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=short_text)])
]
short_text = 'x' * (threshold - 1)
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
result = await stage.process(query, 'LongTextProcessStage')
+7 -7
View File
@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
# Create query with 3 messages (within limit)
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
# Create query with many messages exceeding limit
# 7 messages = 3 full rounds + 1 current user
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = []
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='hello'),
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='user1'),
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='old1'),
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
trun = trun_cls(app)
break
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_truncate_config(max_round=3)
query.messages = [
provider_message.Message(role='user', content='m1'),
+15 -12
View File
@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello world")
query = text_query('hello world')
result = await stage.process(query, 'PreProcessor')
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("test message")
query = text_query('test message')
result = await stage.process(query, 'PreProcessor')
@@ -194,13 +194,16 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app)
# Image query with base64
query = image_query(text="look at this", url=None)
query = image_query(text='look at this', url=None)
# Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([
platform_message.Plain(text="look at this"),
platform_message.Image(base64="data:image/png;base64,abc123"),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text='look at this'),
platform_message.Image(base64='data:image/png;base64,abc123'),
]
)
query.message_chain = chain
result = await stage.process(query, 'PreProcessor')
@@ -238,7 +241,7 @@ class TestPreProcessorImageSegment:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = image_query(text="describe this")
query = image_query(text='describe this')
result = await stage.process(query, 'PreProcessor')
@@ -276,7 +279,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
# Set pipeline config with primary model
query.pipeline_config = {
@@ -335,7 +338,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = {
'ai': {
@@ -384,7 +387,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello", sender_id=67890)
query = text_query('hello', sender_id=67890)
result = await stage.process(query, 'PreProcessor')
@@ -421,7 +424,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = group_text_query("hello", group_id=99999)
query = group_text_query('hello', group_id=99999)
result = await stage.process(query, 'PreProcessor')
+18 -58
View File
@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
'safety': {
'rate-limit': {
'window-length': 60, # 60 seconds window
'limitation': 10, # 10 requests per window
'limitation': 10, # 10 requests per window
'strategy': 'drop',
}
}
@@ -75,11 +75,9 @@ class TestFixedWindowAlgo:
# Make requests within limit
for i in range(10):
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
)
assert result is True, f"Request {i+1} should be allowed"
assert result is True, f'Request {i + 1} should be allowed'
@pytest.mark.asyncio
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -91,20 +89,12 @@ class TestFixedWindowAlgo:
# Exhaust the limit
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Next request should be denied
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
assert result is False, "Request exceeding limit should be denied"
assert result is False, 'Request exceeding limit should be denied'
@pytest.mark.asyncio
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -116,20 +106,14 @@ class TestFixedWindowAlgo:
# Exhaust limit for session 1
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session1'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
# Session 2 should still have its own limit
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session2'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
)
assert result is True, "Different session should have independent limit"
assert result is True, 'Different session should have independent limit'
@pytest.mark.asyncio
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
@@ -150,19 +134,11 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request allowed
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result1 is True
# Second request denied
result2 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result2 is False
@pytest.mark.asyncio
@@ -174,11 +150,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request creates container
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
expected_key = 'LauncherTypes.PERSON_12345'
@@ -230,7 +202,7 @@ class TestFixedWindowAlgo:
# New request should be allowed (new window)
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
assert result is True, "New window should allow new requests"
assert result is True, 'New window should allow new requests'
@pytest.mark.asyncio
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
@@ -256,29 +228,21 @@ class TestFixedWindowAlgo:
# First request allowed
start_time = time.time()
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
assert result1 is True
# Exhaust limit
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
# Third request should wait and then succeed
result3 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
elapsed = time.time() - start_time
assert result3 is True, "After wait, request should succeed"
assert result3 is True, 'After wait, request should succeed'
# Should have waited approximately until next window
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
# Note: This is a timing-sensitive test, so we use a generous tolerance
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
@pytest.mark.asyncio
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -289,11 +253,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# release_access is empty in current implementation
await algo.release_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Should not raise or change state
assert 'person_12345' not in algo.containers
+25 -27
View File
@@ -55,7 +55,7 @@ def make_session():
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
platform_message.MessageChain([platform_message.Plain(text="response")])
]
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
query.resp_message_chain = []
results = []
@@ -125,7 +123,7 @@ class TestResponseWrapperCommand:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -133,7 +131,7 @@ class TestResponseWrapperCommand:
command_resp = Mock()
command_resp.role = 'command'
command_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
)
query.resp_messages = [command_resp]
@@ -163,7 +161,7 @@ class TestResponseWrapperPlugin:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -171,7 +169,7 @@ class TestResponseWrapperPlugin:
plugin_resp = Mock()
plugin_resp.role = 'plugin'
plugin_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
)
query.resp_messages = [plugin_resp]
@@ -211,17 +209,17 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello back!"
assistant_resp.content = 'Hello back!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
)
query.resp_messages = [assistant_resp]
@@ -247,7 +245,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -292,7 +290,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -303,10 +301,10 @@ class TestResponseWrapperAssistant:
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Processing..."
assistant_resp.content = 'Processing...'
assistant_resp.tool_calls = [mock_tool_call]
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
)
query.resp_messages = [assistant_resp]
@@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello!"
assistant_resp.content = 'Hello!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
)
query.resp_messages = [assistant_resp]
@@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply:
app.sess_mgr.get_session = AsyncMock(return_value=session)
# Mock plugin connector with custom reply
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
@@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Default reply"
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
@@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply:
assert len(results[0].new_query.resp_message_chain) == 1
# Should be the custom chain
chain = results[0].new_query.resp_message_chain[0]
assert "Custom reply" in str(chain)
assert 'Custom reply' in str(chain)
class TestResponseWrapperVariables:
@@ -452,7 +450,7 @@ class TestResponseWrapperVariables:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
@@ -460,10 +458,10 @@ class TestResponseWrapperVariables:
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello"
assistant_resp.content = 'Hello'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
)
query.resp_messages = [assistant_resp]
@@ -0,0 +1,146 @@
"""Unit tests for ResponseWrapper outbound-attachment helpers.
Covers the sandbox -> user attachment path added for the Box attachment
round-trip:
* ``_is_final_assistant_message`` only the terminal, tool-call-free assistant
message (or a final MessageChunk) should trigger collection.
* ``_append_outbound_attachments`` collects sandbox outbox files exactly once
per query and maps each descriptor to the right platform component, swallowing
collection errors.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from langbot.pkg.pipeline.wrapper.wrapper import ResponseWrapper
def _make_wrapper(box_service) -> ResponseWrapper:
app = SimpleNamespace(logger=Mock())
wrapper = ResponseWrapper.__new__(ResponseWrapper)
wrapper.ap = app
return wrapper
def _make_query():
return SimpleNamespace(variables={})
def test_is_final_assistant_message_plain_assistant():
wrapper = _make_wrapper(box_service=None)
msg = provider_message.Message(role='assistant', content='done')
assert wrapper._is_final_assistant_message(msg) is True
def test_is_final_assistant_message_rejects_non_assistant():
wrapper = _make_wrapper(box_service=None)
msg = provider_message.Message(role='tool', content='{}')
assert wrapper._is_final_assistant_message(msg) is False
def test_is_final_assistant_message_rejects_tool_call_round():
wrapper = _make_wrapper(box_service=None)
msg = provider_message.Message(
role='assistant',
content='calling',
tool_calls=[
provider_message.ToolCall(
id='c1',
type='function',
function=provider_message.FunctionCall(name='exec', arguments='{}'),
)
],
)
assert wrapper._is_final_assistant_message(msg) is False
def test_is_final_assistant_message_non_final_chunk():
wrapper = _make_wrapper(box_service=None)
chunk = provider_message.MessageChunk(role='assistant', content='partial', is_final=False)
assert wrapper._is_final_assistant_message(chunk) is False
final_chunk = provider_message.MessageChunk(role='assistant', content='partial', is_final=True)
assert wrapper._is_final_assistant_message(final_chunk) is True
@pytest.mark.asyncio
async def test_append_outbound_attachments_maps_each_type():
box_service = SimpleNamespace(
available=True,
collect_outbound_attachments=AsyncMock(
return_value=[
{'type': 'Image', 'base64': 'data:image/png;base64,iVBORw0K'},
{'type': 'Voice', 'base64': 'data:audio/wav;base64,UklGRg=='},
{'type': 'File', 'name': 'report.xlsx', 'base64': 'data:app;base64,UEsDBA=='},
]
),
)
wrapper = _make_wrapper(box_service)
wrapper.ap.box_service = box_service
query = _make_query()
chain = platform_message.MessageChain([])
await wrapper._append_outbound_attachments(query, chain)
kinds = [type(c).__name__ for c in chain]
assert kinds == ['Image', 'Voice', 'File']
assert query.variables['_sandbox_outbound_collected'] is True
# File keeps its name
file_comp = chain[2]
assert getattr(file_comp, 'name', None) == 'report.xlsx'
@pytest.mark.asyncio
async def test_append_outbound_attachments_runs_once_per_query():
box_service = SimpleNamespace(
available=True,
collect_outbound_attachments=AsyncMock(return_value=[]),
)
wrapper = _make_wrapper(box_service)
wrapper.ap.box_service = box_service
query = _make_query()
query.variables['_sandbox_outbound_collected'] = True
chain = platform_message.MessageChain([])
await wrapper._append_outbound_attachments(query, chain)
box_service.collect_outbound_attachments.assert_not_awaited()
assert len(chain) == 0
@pytest.mark.asyncio
async def test_append_outbound_attachments_noop_without_box_service():
wrapper = _make_wrapper(box_service=None)
wrapper.ap.box_service = None
query = _make_query()
chain = platform_message.MessageChain([])
await wrapper._append_outbound_attachments(query, chain)
assert len(chain) == 0
# not marked collected, since service is unavailable
assert '_sandbox_outbound_collected' not in query.variables
@pytest.mark.asyncio
async def test_append_outbound_attachments_swallows_collection_error():
box_service = SimpleNamespace(
available=True,
collect_outbound_attachments=AsyncMock(side_effect=RuntimeError('boom')),
)
wrapper = _make_wrapper(box_service)
wrapper.ap.box_service = box_service
query = _make_query()
chain = platform_message.MessageChain([])
# must not raise
await wrapper._append_outbound_attachments(query, chain)
assert len(chain) == 0
wrapper.ap.logger.warning.assert_called_once()
@@ -0,0 +1,92 @@
"""Unit tests for WebSocketAdapter._process_image_components.
The web debug client uploads Image / Voice / File components carrying a storage
key in ``path``. This helper resolves each to a base64 data URI (so multimodal
LLM input and the Box sandbox inbox have usable bytes), then deletes the
consumed storage object and clears ``path``. Covers mimetype selection per
type and graceful error handling.
"""
from __future__ import annotations
import base64
from unittest.mock import AsyncMock, Mock
import pytest
from langbot.pkg.platform.sources.websocket_adapter import WebSocketAdapter
def _make_adapter(load_return=b'hello', load_side_effect=None):
provider = Mock()
provider.load = AsyncMock(return_value=load_return, side_effect=load_side_effect)
provider.delete = AsyncMock()
ap = Mock()
ap.storage_mgr.storage_provider = provider
logger = Mock()
logger.error = AsyncMock()
# WebSocketAdapter is a pydantic model; bypass full __init__/validation.
adapter = WebSocketAdapter.model_construct(ap=ap, logger=logger)
return adapter, provider
@pytest.mark.asyncio
async def test_image_jpeg_mimetype_and_cleanup():
adapter, provider = _make_adapter(load_return=b'\xff\xd8\xff')
chain = [{'type': 'Image', 'path': 'storage://abc/photo.jpg'}]
await adapter._process_image_components(chain)
expected_b64 = base64.b64encode(b'\xff\xd8\xff').decode('utf-8')
assert chain[0]['base64'] == f'data:image/jpeg;base64,{expected_b64}'
assert chain[0]['path'] == '' # consumed
provider.delete.assert_awaited_once_with('storage://abc/photo.jpg')
@pytest.mark.asyncio
async def test_image_defaults_to_png():
adapter, _ = _make_adapter()
chain = [{'type': 'Image', 'path': 'storage://abc/blob'}]
await adapter._process_image_components(chain)
assert chain[0]['base64'].startswith('data:image/png;base64,')
@pytest.mark.asyncio
async def test_voice_uses_guessed_or_wav_mimetype():
adapter, _ = _make_adapter()
chain = [{'type': 'Voice', 'path': 'storage://abc/clip.wav'}]
await adapter._process_image_components(chain)
assert chain[0]['base64'].startswith('data:audio/')
@pytest.mark.asyncio
async def test_file_uses_octet_stream_fallback():
adapter, _ = _make_adapter()
chain = [{'type': 'File', 'path': 'storage://abc/unknownblob'}]
await adapter._process_image_components(chain)
assert chain[0]['base64'].startswith('data:application/octet-stream;base64,')
@pytest.mark.asyncio
async def test_skips_components_without_path_or_unknown_type():
adapter, provider = _make_adapter()
chain = [
{'type': 'Image', 'path': ''}, # no path
{'type': 'Plain', 'path': 'storage://abc/x'}, # not a file component
{'type': 'At', 'target': '123'}, # no path key at all
]
await adapter._process_image_components(chain)
provider.load.assert_not_awaited()
assert 'base64' not in chain[0]
assert 'base64' not in chain[1]
@pytest.mark.asyncio
async def test_load_failure_is_logged_not_raised():
adapter, _ = _make_adapter(load_side_effect=RuntimeError('storage down'))
chain = [{'type': 'File', 'path': 'storage://abc/doc.pdf'}]
# must not raise
await adapter._process_image_components(chain)
assert 'base64' not in chain[0]
adapter.logger.error.assert_awaited_once()
@@ -6,6 +6,7 @@ Tests cover:
- RAG methods (ingest, retrieve, schema)
- Disabled plugin early returns
"""
from __future__ import annotations
import pytest
@@ -86,16 +87,12 @@ class TestListPlugins:
return_value=[
{
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Command'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}],
'debug': False,
},
{
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Tool'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}],
'debug': False,
},
]
@@ -127,9 +124,7 @@ class TestListPlugins:
},
]
)
connector.ap.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(__iter__=lambda self: iter([]))
)
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([])))
result = await connector.list_plugins()
@@ -230,7 +225,8 @@ class TestCallParser:
)
connector.handler.parse_document.assert_called_once_with(
'author', 'parser',
'author',
'parser',
{'mime_type': 'text/plain', 'filename': 'test.txt'},
b'file content',
)
@@ -251,9 +247,7 @@ class TestRAGMethods:
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
connector.handler.rag_ingest_document.assert_called_once_with(
'author', 'engine', {'file': 'test.pdf'}
)
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'})
assert result['status'] == 'success'
@pytest.mark.asyncio
@@ -264,14 +258,16 @@ class TestRAGMethods:
connector.handler = AsyncMock()
connector.handler.retrieve_knowledge = AsyncMock(
return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
return_value={
'results': [
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
]
}
)
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
connector.handler.retrieve_knowledge.assert_called_once_with(
'author', 'engine', '', {'query': 'test'}
)
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'})
assert result == {
'results': [
{
@@ -290,9 +286,7 @@ class TestRAGMethods:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}})
result = await connector.get_rag_creation_schema('author/engine')
@@ -326,9 +320,7 @@ class TestRAGMethods:
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
connector.handler.rag_on_kb_create.assert_called_once_with(
'author', 'engine', 'kb-uuid', {'model': 'test'}
)
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'})
@pytest.mark.asyncio
async def test_rag_on_kb_delete(self):
@@ -354,9 +346,7 @@ class TestRAGMethods:
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
connector.handler.rag_delete_document.assert_called_once_with(
'author', 'engine', 'doc-uuid', 'kb-uuid'
)
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid')
assert result is True
@@ -446,9 +436,7 @@ class TestGetPluginInfo:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_plugin_info = AsyncMock(
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
)
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}})
result = await connector.get_plugin_info('author', 'plugin')
@@ -470,9 +458,7 @@ class TestSetPluginConfig:
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
connector.handler.set_plugin_config.assert_called_once_with(
'author', 'plugin', {'setting': 'value'}
)
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'})
class TestPingPluginRuntime:
@@ -3,6 +3,7 @@
Tests cover:
- _parse_plugin_id() parsing and validation
"""
from __future__ import annotations
import pytest
+5 -4
View File
@@ -6,6 +6,7 @@ Tests cover:
- Handling missing requirements.txt
- Handling empty/malformed requirements.txt
"""
from __future__ import annotations
import zipfile
@@ -82,13 +83,13 @@ class TestExtractDepsMetadata:
"""Test that comments and empty lines are filtered."""
connector_instance = create_mock_connector()
requirements = '''# This is a comment
requirements = """# This is a comment
requests>=2.0
# Another comment
flask==1.0
numpy'''
numpy"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()
@@ -147,9 +148,9 @@ numpy'''
"""Test handling requirements.txt with only comments."""
connector_instance = create_mock_connector()
requirements = '''# Comment 1
requirements = """# Comment 1
# Comment 2
# Comment 3'''
# Comment 3"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()
+30 -27
View File
@@ -40,11 +40,13 @@ class TestHandlerQueryVariables:
"""Test set_query_var returns error when query not found."""
runtime_handler = make_handler(mock_app)
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -58,11 +60,13 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code == 0
assert mock_query.variables['test_var'] == 'test_value'
@@ -76,10 +80,12 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'existing_var',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'existing_var',
}
)
assert response.code == 0
assert response.data == {'value': 'existing_value'}
@@ -93,9 +99,11 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
'query_id': 'test-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value](
{
'query_id': 'test-query',
}
)
assert response.code == 0
assert response.data == {'vars': mock_query.variables}
@@ -108,7 +116,7 @@ class TestHandlerRagErrorResponse:
"""Test basic error response creation."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = Exception("test error")
error = Exception('test error')
response = _make_rag_error_response(error, 'TestError')
# ActionResponse is a pydantic model, check message field
@@ -120,13 +128,8 @@ class TestHandlerRagErrorResponse:
"""Test error response with extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = ValueError("invalid input")
response = _make_rag_error_response(
error,
'ValidationError',
field='name',
value='test'
)
error = ValueError('invalid input')
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test')
assert 'ValidationError' in response.message
assert 'field=name' in response.message
@@ -137,7 +140,7 @@ class TestHandlerRagErrorResponse:
"""Test error response includes exception type."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = RuntimeError("connection failed")
error = RuntimeError('connection failed')
response = _make_rag_error_response(error, 'ConnectionError')
assert 'RuntimeError' in response.message
@@ -148,7 +151,7 @@ class TestHandlerRagErrorResponse:
"""Test error response with no extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = KeyError("missing_key")
error = KeyError('missing_key')
response = _make_rag_error_response(error, 'LookupError')
# No context parts means no brackets
+54 -42
View File
@@ -47,12 +47,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -82,12 +84,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 3
@@ -161,9 +165,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'new')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new'))
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -203,9 +205,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'x')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x'))
assert response.code != 0
assert '1 > 0 bytes' in response.message
@@ -228,10 +228,12 @@ class TestGetPluginSettings:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -255,10 +257,12 @@ class TestGetPluginSettings:
)
app.persistence_mgr.execute_async.return_value = make_result(setting)
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -286,11 +290,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code == 0
assert response.data == {
@@ -303,11 +309,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code != 0
assert 'Storage with key test-key not found' in response.message
@@ -329,9 +337,11 @@ class TestHandlerQueryLookup:
"""Query-bound actions return error when query_id is not cached."""
runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'nonexistent-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'nonexistent-query',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -343,9 +353,11 @@ class TestHandlerQueryLookup:
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
app.query_pool.cached_queries['existing-query'] = query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'existing-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'existing-query',
}
)
assert response.code == 0
assert response.data == {'bot_uuid': 'test-bot-uuid'}
@@ -4,6 +4,7 @@ Tests cover:
- _make_rag_error_response() helper function
- RuntimeConnectionHandler cleanup_plugin_data method
"""
from __future__ import annotations
import pytest
@@ -23,7 +24,7 @@ class TestMakeRagErrorResponse:
"""Test basic error response creation."""
handler = get_handler_module()
error = ValueError("test error message")
error = ValueError('test error message')
result = handler._make_rag_error_response(error, 'TestError')
# ActionResponse.error() returns code=1 (error status)
@@ -36,7 +37,7 @@ class TestMakeRagErrorResponse:
"""Test that error type is included in message."""
handler = get_handler_module()
error = RuntimeError("something went wrong")
error = RuntimeError('something went wrong')
result = handler._make_rag_error_response(error, 'VectorStoreError')
assert '[VectorStoreError/RuntimeError]' in result.message
@@ -45,7 +46,7 @@ class TestMakeRagErrorResponse:
"""Test that extra context fields are included."""
handler = get_handler_module()
error = Exception("embedding failed")
error = Exception('embedding failed')
result = handler._make_rag_error_response(
error,
'EmbeddingError',
@@ -71,7 +72,7 @@ class TestMakeRagErrorResponse:
"""Test multiple context fields are comma separated."""
handler = get_handler_module()
error = IOError("file not found")
error = IOError('file not found')
result = handler._make_rag_error_response(
error,
'FileServiceError',
@@ -119,9 +120,7 @@ class TestCleanupPluginData:
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
handler_instance.ap = mock_app
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
handler_instance, 'author', 'plugin-name'
)
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name')
# Should have at least 2 calls: one for settings, one for binary storage
assert mock_app.persistence_mgr.execute_async.call_count >= 2
assert mock_app.persistence_mgr.execute_async.call_count >= 2
+7 -4
View File
@@ -88,7 +88,10 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
return provider_message.Message(
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
)
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results."""
@@ -135,8 +138,10 @@ def mock_app_for_modelmgr():
# Fake persistence manager - returns empty results by default
app.persistence_mgr = SimpleNamespace()
async def default_execute(query):
return _make_mock_result([])
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
# Fake discover engine
@@ -165,9 +170,7 @@ def fake_requester_registry(mock_app_for_modelmgr):
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
app.discover.get_components_by_kind = Mock(
return_value=[fake_component, another_component]
)
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component])
model_mgr = ModelManager(app)
return model_mgr
@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
with pytest.raises(DifyAPIError, match='不支持'):
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app
assert runner.ap == mock_app
@@ -0,0 +1,93 @@
"""Unit tests for LiteLLMRequester._convert_messages.
Focus: the content-part normalization that (a) converts image_base64 parts to
the OpenAI image_url shape and (b) drops non-image file parts (file_base64 /
file_url) which OpenAI-compatible chat models reject. The latter is essential
for Voice/File attachments including ones replayed from conversation history
since the agent consumes their bytes via the sandbox, not the model payload.
"""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
def _make_requester() -> LiteLLMRequester:
# _convert_messages does not touch instance config, so bypass __init__.
return LiteLLMRequester.__new__(LiteLLMRequester)
def test_convert_messages_drops_file_base64_part():
req = _make_requester()
msg = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('analyze this audio'),
provider_message.ContentElement.from_file_base64('data:audio/wav;base64,AAAA', 'voice.wav'),
],
)
out = req._convert_messages([msg])
parts = out[0]['content']
types = [p.get('type') for p in parts]
assert 'file_base64' not in types
assert types == ['text']
assert parts[0]['text'] == 'analyze this audio'
def test_convert_messages_drops_file_url_part():
req = _make_requester()
msg = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('here is a doc'),
provider_message.ContentElement.from_file_url('http://example.com/report.xlsx', 'report.xlsx'),
],
)
out = req._convert_messages([msg])
types = [p.get('type') for p in out[0]['content']]
assert types == ['text']
def test_convert_messages_keeps_image_and_converts_to_image_url():
req = _make_requester()
msg = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('look'),
provider_message.ContentElement.from_image_base64('data:image/png;base64,AAAA'),
],
)
out = req._convert_messages([msg])
parts = out[0]['content']
types = [p.get('type') for p in parts]
# image is preserved and reshaped to the OpenAI image_url form
assert types == ['text', 'image_url']
img_part = parts[1]
assert img_part['image_url'] == {'url': 'data:image/png;base64,AAAA'}
assert 'image_base64' not in img_part
def test_convert_messages_mixed_history_strips_only_files():
req = _make_requester()
# Simulate replayed history: an old voice turn + a current text turn.
history_voice = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('old audio turn'),
provider_message.ContentElement.from_file_base64('data:audio/wav;base64,BBBB', 'voice.wav'),
],
)
current = provider_message.Message(
role='user',
content=[provider_message.ContentElement.from_text('now do the csv')],
)
out = req._convert_messages([history_voice, current])
assert [p.get('type') for p in out[0]['content']] == ['text']
assert [p.get('type') for p in out[1]['content']] == ['text']
def test_convert_messages_plain_string_content_untouched():
req = _make_requester()
msg = provider_message.Message(role='user', content='just text')
out = req._convert_messages([msg])
assert out[0]['content'] == 'just text'
@@ -1062,9 +1062,7 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
mock_get_model_info.side_effect = (
lambda model: {'max_input_tokens': 131072}
if model == 'moonshot/moonshot-v1-128k'
else {}
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {}
)
assert requester._safe_context_length('moonshot-v1-128k') == 131072
@@ -0,0 +1,146 @@
"""Unit tests for LocalAgentRunner._inject_inbound_attachments.
Covers the user -> sandbox attachment path added for the Box attachment
round-trip:
* materialized descriptors are stashed on the query and described to the model
via an appended text note (in-sandbox paths + outbox convention);
* non-image file parts (file_base64 / file_url) are stripped from the user
message content because OpenAI-compatible chat models reject them, while
image and text parts are kept for vision models;
* the helper is a no-op when the box service is unavailable or yields nothing,
and never raises into the chat turn on materialization failure.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
def _make_runner(box_service) -> LocalAgentRunner:
runner = LocalAgentRunner.__new__(LocalAgentRunner)
runner.ap = SimpleNamespace(logger=Mock(), box_service=box_service)
return runner
def _make_query():
return SimpleNamespace(variables={}, query_id='q-123')
def _box_service(attachments):
svc = SimpleNamespace(
available=True,
OUTBOX_MOUNT_DIR='/outbox',
materialize_inbound_attachments=AsyncMock(return_value=attachments),
)
return svc
@pytest.mark.asyncio
async def test_inject_strips_file_parts_and_appends_note():
box = _box_service([{'type': 'Voice', 'path': '/inbox/q-123/voice.wav', 'size': 176000}])
runner = _make_runner(box)
query = _make_query()
user_message = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('transcribe this'),
provider_message.ContentElement.from_file_base64('data:audio/wav;base64,AAAA', 'voice.wav'),
],
)
await runner._inject_inbound_attachments(query, user_message)
types = [getattr(ce, 'type', None) for ce in user_message.content]
# file_base64 dropped; text kept; sandbox-path note appended as text
assert 'file_base64' not in types
assert types.count('text') == 2
note = user_message.content[-1].text
assert '/inbox/q-123/voice.wav' in note
assert '/outbox/q-123' in note
# descriptors stashed for downstream stages
assert query.variables['_sandbox_inbound_attachments'] == box.materialize_inbound_attachments.return_value
@pytest.mark.asyncio
async def test_inject_keeps_image_parts():
box = _box_service([{'type': 'Image', 'path': '/inbox/q-123/pic.png', 'size': 1234}])
runner = _make_runner(box)
query = _make_query()
user_message = provider_message.Message(
role='user',
content=[
provider_message.ContentElement.from_text('what is this'),
provider_message.ContentElement.from_image_base64('data:image/png;base64,iVBORw0K'),
],
)
await runner._inject_inbound_attachments(query, user_message)
types = [getattr(ce, 'type', None) for ce in user_message.content]
assert 'image_base64' in types # vision part preserved
assert types[-1] == 'text' # note appended last
@pytest.mark.asyncio
async def test_inject_promotes_string_content_to_list_with_note():
box = _box_service([{'type': 'File', 'path': '/inbox/q-123/data.csv', 'size': 42}])
runner = _make_runner(box)
query = _make_query()
user_message = provider_message.Message(role='user', content='clean this csv')
await runner._inject_inbound_attachments(query, user_message)
assert isinstance(user_message.content, list)
assert [getattr(ce, 'type', None) for ce in user_message.content] == ['text', 'text']
assert user_message.content[0].text == 'clean this csv'
assert '/inbox/q-123/data.csv' in user_message.content[1].text
@pytest.mark.asyncio
async def test_inject_noop_without_box_service():
runner = _make_runner(box_service=None)
query = _make_query()
user_message = provider_message.Message(role='user', content='hello')
await runner._inject_inbound_attachments(query, user_message)
assert user_message.content == 'hello'
assert '_sandbox_inbound_attachments' not in query.variables
@pytest.mark.asyncio
async def test_inject_noop_when_no_attachments():
box = _box_service([])
runner = _make_runner(box)
query = _make_query()
user_message = provider_message.Message(role='user', content='hello')
await runner._inject_inbound_attachments(query, user_message)
assert user_message.content == 'hello'
assert '_sandbox_inbound_attachments' not in query.variables
@pytest.mark.asyncio
async def test_inject_swallows_materialization_error():
box = SimpleNamespace(
available=True,
OUTBOX_MOUNT_DIR='/outbox',
materialize_inbound_attachments=AsyncMock(side_effect=RuntimeError('disk full')),
)
runner = _make_runner(box)
query = _make_query()
user_message = provider_message.Message(role='user', content='hello')
# must not raise
await runner._inject_inbound_attachments(query, user_message)
assert user_message.content == 'hello'
runner.ap.logger.warning.assert_called_once()
@@ -635,7 +635,9 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
model_mgr = fake_requester_registry
@@ -648,7 +650,9 @@ async def test_model_manager_load_llm_model_with_provider(fake_requester_registr
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider_from_row(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
model_mgr = fake_requester_registry
@@ -661,7 +665,9 @@ async def test_model_manager_load_llm_model_with_provider_from_row(fake_requeste
@pytest.mark.asyncio
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_embedding_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
model_mgr = fake_requester_registry
@@ -43,6 +43,7 @@ class TestableRequester(requester.ProviderAPIRequester):
remove_think=False,
):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Testable response')],
@@ -289,7 +290,9 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
result = await provider.invoke_llm(query, runtime_llm_model, messages)
@@ -330,7 +333,9 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
chunks = []
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
@@ -576,7 +581,9 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
with pytest.raises(RequesterError):
await provider.invoke_llm(query, model, messages)
@@ -5,6 +5,7 @@ Tests cover:
- Conversation creation with prompts
- Session concurrency semaphore
"""
from __future__ import annotations
import pytest
@@ -60,11 +61,7 @@ class TestSessionManagerGetSession:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -173,11 +170,7 @@ class TestSessionManagerGetConversation:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -201,17 +194,13 @@ class TestSessionManagerGetConversation:
return query
@pytest.mark.asyncio
async def test_creates_conversation_with_prompt(
self, mock_app_with_config, sample_query, sample_session
):
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session):
"""Test that get_conversation creates conversation with prompt."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
@@ -234,21 +223,15 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
# First call creates conversation
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
# Second call with same pipeline should return same conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
assert conv1 is conv2
assert len(sample_session.conversations) == 1
@@ -262,36 +245,26 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
# First call with pipeline1
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1')
# Second call with different pipeline should create new conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2')
assert conv1 is not conv2
assert len(sample_session.conversations) == 2
assert sample_session.using_conversation is conv2
@pytest.mark.asyncio
async def test_conversation_has_empty_messages(
self, mock_app_with_config, sample_query, sample_session
):
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session):
"""Test that created conversation has empty messages list."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
@@ -300,22 +273,17 @@ class TestSessionManagerGetConversation:
assert conversation.messages == []
@pytest.mark.asyncio
async def test_prompt_messages_from_config(
self, mock_app_with_config, sample_query, sample_session
):
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session):
"""Test that prompt messages are created from prompt_config."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'System message'},
{'role': 'user', 'content': 'User message'}
]
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
)
assert conversation.prompt.name == 'default'
assert len(conversation.prompt.messages) == 2
assert len(conversation.prompt.messages) == 2
@@ -136,6 +136,7 @@ class TestToolManagerSchemaGeneration:
assert 'description' in func
assert 'parameters' in func
class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method."""
+2 -1
View File
@@ -3,6 +3,7 @@
Tests cover:
- _to_i18n_name() static method
"""
from __future__ import annotations
from importlib import import_module
@@ -60,4 +61,4 @@ class TestToI18nName:
kbmgr = get_kbmgr_module()
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
result = kbmgr.RAGManager._to_i18n_name(input_dict)
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
+14 -37
View File
@@ -6,6 +6,7 @@ Tests cover:
- Knowledge engine enrichment
- KB loading and removal
"""
from __future__ import annotations
import pytest
@@ -101,13 +102,9 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin error')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin error'))
manager = rag_module.RAGManager(mock_app)
@@ -128,9 +125,7 @@ class TestRAGManagerCreateKnowledgeBase:
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'plugin_id': 'author/engine'}]
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(return_value=[{'plugin_id': 'author/engine'}])
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock()
@@ -206,9 +201,7 @@ class TestRuntimeKnowledgeBaseOnKBCreate:
mock_app = create_mock_app()
mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Plugin failed')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Plugin failed'))
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -245,9 +238,7 @@ class TestRuntimeKnowledgeBaseIngestDocument:
mock_app = create_mock_app()
mock_kb = create_mock_kb_entity()
mock_app.plugin_connector.call_rag_ingest = AsyncMock(
return_value={'status': 'success'}
)
mock_app.plugin_connector.call_rag_ingest = AsyncMock(return_value={'status': 'success'})
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -304,14 +295,10 @@ class TestRAGManagerLoadKnowledgeBasesFromDB:
# KB that will cause initialize to fail
mock_kb = create_mock_kb_entity()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb])))
# Make initialize fail by having plugin_connector throw error
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(
side_effect=Exception('Init failed')
)
mock_app.plugin_connector.rag_on_kb_create = AsyncMock(side_effect=Exception('Init failed'))
manager = rag_module.RAGManager(mock_app)
# Should not raise - errors are caught
@@ -411,9 +398,7 @@ class TestRuntimeKnowledgeBaseRetrieve:
mock_kb = create_mock_kb_entity()
mock_kb.retrieval_settings = {}
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(
return_value={'results': []}
)
mock_app.plugin_connector.call_rag_retrieve = AsyncMock(return_value={'results': []})
runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb)
@@ -682,9 +667,7 @@ class TestRAGManagerGetAllDetails:
"""Test returns empty list when no knowledge bases."""
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
manager = rag_module.RAGManager(mock_app)
result = await manager.get_all_knowledge_base_details()
@@ -699,9 +682,7 @@ class TestRAGManagerGetAllDetails:
# Mock DB result
mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(all=Mock(return_value=[mock_kb_row]))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[mock_kb_row])))
mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
)
@@ -724,9 +705,7 @@ class TestRAGManagerGetDetails:
"""Test returns None when KB doesn't exist."""
rag_module = get_rag_module()
mock_app = create_mock_app()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=None))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=None)))
manager = rag_module.RAGManager(mock_app)
result = await manager.get_knowledge_base_details('nonexistent')
@@ -740,9 +719,7 @@ class TestRAGManagerGetDetails:
mock_app = create_mock_app()
mock_kb_row = Mock()
mock_app.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(first=Mock(return_value=mock_kb_row))
)
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(first=Mock(return_value=mock_kb_row)))
mock_app.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'kb1', 'knowledge_engine_plugin_id': 'author/engine'}
)
@@ -791,4 +768,4 @@ class TestRAGManagerLoadKnowledgeBase:
await manager.load_knowledge_base(kb_dict)
assert 'kb-uuid' in manager.knowledge_bases
assert 'kb-uuid' in manager.knowledge_bases
+7 -8
View File
@@ -121,10 +121,12 @@ class TestRAGRuntimeServiceVectorSearch:
"""Create mock app."""
mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.search = AsyncMock(return_value=[
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
])
mock_app.vector_db_mgr.search = AsyncMock(
return_value=[
{'id': 'id1', 'distance': 0.1, 'metadata': {'file_id': 'abc'}},
{'id': 'id2', 'distance': 0.2, 'metadata': {'file_id': 'def'}},
]
)
return mock_app
def _make_rag_import_mocks(self):
@@ -301,10 +303,7 @@ class TestRAGRuntimeServiceVectorList:
mock_app = MagicMock()
mock_app.vector_db_mgr = MagicMock()
mock_app.vector_db_mgr.list_by_filter = AsyncMock(
return_value=(
[{'id': 'id1', 'metadata': {'file_id': 'abc'}}],
10
)
return_value=([{'id': 'id1', 'metadata': {'file_id': 'abc'}}], 10)
)
return mock_app
@@ -21,8 +21,8 @@ from langbot.pkg.storage.providers.localstorage import LocalStorageProvider
@pytest.fixture
def storage_provider(tmp_path):
"""Create a LocalStorageProvider with a temporary storage path."""
storage_path = str(tmp_path / "storage")
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
storage_path = str(tmp_path / 'storage')
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
mock_app = Mock()
provider = LocalStorageProvider(mock_app)
yield provider, storage_path
@@ -35,15 +35,15 @@ class TestPathTraversalPrevention:
async def test_absolute_path_save_rejected(self, storage_provider, tmp_path):
"""Saving with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "pwned.txt")
target_file = str(tmp_path / 'pwned.txt')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError)):
await provider.save(target_file, b"malicious content")
await provider.save(target_file, b'malicious content')
# The file must NOT exist outside the storage directory
assert not os.path.exists(target_file), (
f"Path traversal succeeded: file was written outside storage to {target_file}"
f'Path traversal succeeded: file was written outside storage to {target_file}'
)
@pytest.mark.asyncio
@@ -52,32 +52,28 @@ class TestPathTraversalPrevention:
provider, storage_path = storage_provider
# Create a file outside the storage directory
target_file = str(tmp_path / "secret.txt")
with open(target_file, "wb") as f:
f.write(b"secret data")
target_file = str(tmp_path / 'secret.txt')
with open(target_file, 'wb') as f:
f.write(b'secret data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(target_file)
assert data != b"secret data", (
"Path traversal succeeded: read file outside storage"
)
assert data != b'secret data', 'Path traversal succeeded: read file outside storage'
@pytest.mark.asyncio
async def test_absolute_path_exists_rejected(self, storage_provider, tmp_path):
"""Exists check with an absolute path key must be blocked or return False."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "check_me.txt")
with open(target_file, "wb") as f:
f.write(b"data")
target_file = str(tmp_path / 'check_me.txt')
with open(target_file, 'wb') as f:
f.write(b'data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
try:
result = await provider.exists(target_file)
assert result is False, (
"Path traversal succeeded: exists() returned True for file outside storage"
)
assert result is False, 'Path traversal succeeded: exists() returned True for file outside storage'
except (ValueError, PermissionError):
pass # Expected
@@ -86,28 +82,26 @@ class TestPathTraversalPrevention:
"""Deleting with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "do_not_delete.txt")
with open(target_file, "wb") as f:
f.write(b"important data")
target_file = str(tmp_path / 'do_not_delete.txt')
with open(target_file, 'wb') as f:
f.write(b'important data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.delete(target_file)
assert os.path.exists(target_file), (
"Path traversal succeeded: file outside storage was deleted"
)
assert os.path.exists(target_file), 'Path traversal succeeded: file outside storage was deleted'
@pytest.mark.asyncio
async def test_absolute_path_size_rejected(self, storage_provider, tmp_path):
"""Size check with an absolute path key must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "measure_me.txt")
with open(target_file, "wb") as f:
f.write(b"some data")
target_file = str(tmp_path / 'measure_me.txt')
with open(target_file, 'wb') as f:
f.write(b'some data')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
await provider.size(target_file)
@@ -116,41 +110,39 @@ class TestPathTraversalPrevention:
"""Relative path traversal with '..' must be blocked."""
provider, storage_path = storage_provider
target_file = str(tmp_path / "above_storage.txt")
with open(target_file, "wb") as f:
f.write(b"above storage secret")
target_file = str(tmp_path / 'above_storage.txt')
with open(target_file, 'wb') as f:
f.write(b'above storage secret')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
relative_key = os.path.join("..", "above_storage.txt")
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
relative_key = os.path.join('..', 'above_storage.txt')
with pytest.raises((ValueError, PermissionError, FileNotFoundError)):
data = await provider.load(relative_key)
assert data != b"above storage secret"
assert data != b'above storage secret'
@pytest.mark.asyncio
async def test_delete_dir_recursive_traversal_rejected(self, storage_provider, tmp_path):
"""delete_dir_recursive with traversal path must be blocked."""
provider, storage_path = storage_provider
outside_dir = tmp_path / "outside_dir"
outside_dir = tmp_path / 'outside_dir'
outside_dir.mkdir()
(outside_dir / "file.txt").write_text("important")
(outside_dir / 'file.txt').write_text('important')
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
with pytest.raises((ValueError, PermissionError)):
await provider.delete_dir_recursive(str(outside_dir))
assert outside_dir.exists(), (
"Path traversal succeeded: directory outside storage was deleted"
)
assert outside_dir.exists(), 'Path traversal succeeded: directory outside storage was deleted'
@pytest.mark.asyncio
async def test_legitimate_key_works(self, storage_provider):
"""Normal keys without traversal must still work."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = "test_image_abc123.png"
content = b"PNG image data"
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
key = 'test_image_abc123.png'
content = b'PNG image data'
await provider.save(key, content)
assert await provider.exists(key) is True
@@ -166,9 +158,9 @@ class TestPathTraversalPrevention:
"""Keys with legitimate subdirectories must still work."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
key = "bot_log_images/img_001.png"
content = b"PNG image data"
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
key = 'bot_log_images/img_001.png'
content = b'PNG image data'
await provider.save(key, content)
assert await provider.exists(key) is True
@@ -181,33 +173,33 @@ class TestPathTraversalPrevention:
"""delete_dir_recursive should handle non-existing directories gracefully."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
# Try to delete a non-existing directory - should not raise
await provider.delete_dir_recursive("nonexistent_dir")
await provider.delete_dir_recursive('nonexistent_dir')
@pytest.mark.asyncio
async def test_delete_dir_recursive_with_files(self, storage_provider):
"""delete_dir_recursive should delete directory with files inside."""
provider, storage_path = storage_provider
with patch("langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH", storage_path):
with patch('langbot.pkg.storage.providers.localstorage.LOCAL_STORAGE_PATH', storage_path):
# Create a directory with files
key1 = "test_dir/file1.txt"
key2 = "test_dir/file2.txt"
await provider.save(key1, b"content1")
await provider.save(key2, b"content2")
key1 = 'test_dir/file1.txt'
key2 = 'test_dir/file2.txt'
await provider.save(key1, b'content1')
await provider.save(key2, b'content2')
# Verify files exist
assert await provider.exists(key1)
assert await provider.exists(key2)
# Delete directory recursively
await provider.delete_dir_recursive("test_dir")
await provider.delete_dir_recursive('test_dir')
# Verify files no longer exist
assert not await provider.exists(key1)
assert not await provider.exists(key2)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])
+4 -1
View File
@@ -8,6 +8,7 @@ Tests cover:
Uses moto library to mock AWS S3 service.
"""
from __future__ import annotations
import pytest
@@ -44,8 +45,10 @@ def mock_app_with_s3_config():
def s3_mock():
"""Set up moto S3 mock context."""
from moto import mock_aws
with mock_aws():
import boto3
# Create bucket for tests that need pre-existing bucket
s3 = boto3.client('s3', region_name='us-east-1')
yield s3
@@ -325,4 +328,4 @@ class TestS3StorageProviderErrorHandling:
await provider.initialize()
with pytest.raises(Exception):
await provider.size('nonexistent.txt')
await provider.size('nonexistent.txt')
@@ -31,7 +31,7 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
mock_app.logger.info.assert_called()
@@ -41,12 +41,12 @@ class TestStorageMgr:
"""Should use local storage when explicitly configured."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {"storage": {"use": "local"}}
mock_app.instance_config.data = {'storage': {'use': 'local'}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -55,14 +55,12 @@ class TestStorageMgr:
"""Should use S3 storage when configured."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
"storage": {"use": "s3", "s3": {"endpoint_url": "https://s3.amazonaws.com"}}
}
mock_app.instance_config.data = {'storage': {'use': 's3', 's3': {'endpoint_url': 'https://s3.amazonaws.com'}}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(S3StorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(S3StorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, S3StorageProvider)
@@ -71,12 +69,12 @@ class TestStorageMgr:
"""Should default to local storage for invalid storage type."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {"storage": {"use": "invalid_type"}}
mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}}
mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app)
with patch.object(LocalStorageProvider, "initialize", new_callable=AsyncMock):
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock):
await storage_mgr.initialize()
assert isinstance(storage_mgr.storage_provider, LocalStorageProvider)
@@ -90,9 +88,7 @@ class TestStorageMgr:
storage_mgr = StorageMgr(mock_app)
with patch.object(
LocalStorageProvider, "initialize", new_callable=AsyncMock
) as mock_init:
with patch.object(LocalStorageProvider, 'initialize', new_callable=AsyncMock) as mock_init:
await storage_mgr.initialize()
mock_init.assert_called_once()
@@ -105,8 +101,8 @@ class TestStorageProviderBase:
mock_app = Mock()
# Use LocalStorageProvider as concrete implementation
with patch("os.path.exists", return_value=True):
with patch("os.makedirs"):
with patch('os.path.exists', return_value=True):
with patch('os.makedirs'):
provider = LocalStorageProvider(mock_app)
assert provider.ap == mock_app
@@ -115,12 +111,12 @@ class TestStorageProviderBase:
"""Provider base initialize should be callable and do nothing."""
mock_app = Mock()
with patch("os.path.exists", return_value=True):
with patch("os.makedirs"):
with patch('os.path.exists', return_value=True):
with patch('os.makedirs'):
provider = LocalStorageProvider(mock_app)
# Initialize should not raise
await provider.initialize()
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])
+7 -9
View File
@@ -8,6 +8,7 @@ Tests cover:
- HTTP request success/failure scenarios
- Source code bug: send_tasks should be instance variable
"""
from __future__ import annotations
import pytest
@@ -38,6 +39,7 @@ class TestTelemetryManagerInit:
manager = telemetry.TelemetryManager(mock_app)
assert manager.telemetry_config == {}
class TestTelemetryManagerInitialize:
"""Tests for initialize() method."""
@@ -218,7 +220,7 @@ class TestPayloadSanitization:
# All null string fields should be empty strings
for field in ['adapter', 'runner', 'runner_category', 'model_name', 'version', 'edition', 'error', 'timestamp']:
assert result[field] == '', f"Field {field} should be empty string, got {result[field]}"
assert result[field] == '', f'Field {field} should be empty string, got {result[field]}'
@pytest.mark.asyncio
async def test_sanitize_string_fields_preserve_values(self):
@@ -418,9 +420,7 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock(
status_code=200,
text='{"code": 0, "msg": "success"}',
json=Mock(return_value={'code': 0, 'msg': 'success'})
status_code=200, text='{"code": 0, "msg": "success"}', json=Mock(return_value={'code': 0, 'msg': 'success'})
)
mock_client = Mock()
@@ -448,9 +448,7 @@ class TestHTTPScenarios:
manager.telemetry_config = {'url': 'https://example.com'}
mock_response = Mock(
status_code=500,
text='Internal Server Error',
json=Mock(return_value={'code': 500, 'msg': 'error'})
status_code=500, text='Internal Server Error', json=Mock(return_value={'code': 500, 'msg': 'error'})
)
mock_client = Mock()
@@ -478,7 +476,7 @@ class TestHTTPScenarios:
mock_response = Mock(
status_code=200,
text='{"code": 400, "msg": "Bad Request"}',
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'})
json=Mock(return_value={'code': 400, 'msg': 'Bad Request'}),
)
mock_client = Mock()
@@ -493,7 +491,7 @@ class TestHTTPScenarios:
assert mock_app.logger.warning.call_count >= 1
# Check that one of the calls contains application error info
all_warnings = [call[0][0] for call in mock_app.logger.warning.call_args_list]
assert any('400' in w for w in all_warnings), f"No warning contained error code 400: {all_warnings}"
assert any('400' in w for w in all_warnings), f'No warning contained error code 400: {all_warnings}'
@pytest.mark.asyncio
async def test_send_timeout_logs_warning(self):
@@ -9,6 +9,7 @@ Tests cover:
Note: Do NOT use 'from __future__ import annotations' because
funcschema.py expects actual type objects, not string annotations.
"""
import pytest
from importlib import import_module
+30 -32
View File
@@ -20,55 +20,53 @@ class TestGetQQImageDownloadableUrl:
def test_basic_url(self):
"""Parse basic image URL."""
url = "http://example.com/image.jpg"
url = 'http://example.com/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/image.jpg"
assert result_url == 'http://example.com/image.jpg'
assert query == {}
def test_url_with_query_params(self):
"""Parse URL with query parameters."""
url = "http://example.com/image.jpg?param1=value1&param2=value2"
url = 'http://example.com/image.jpg?param1=value1&param2=value2'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/image.jpg"
assert query == {"param1": ["value1"], "param2": ["value2"]}
assert result_url == 'http://example.com/image.jpg'
assert query == {'param1': ['value1'], 'param2': ['value2']}
def test_url_with_port(self):
"""Parse URL with port number."""
url = "http://example.com:8080/image.jpg"
url = 'http://example.com:8080/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com:8080/image.jpg"
assert result_url == 'http://example.com:8080/image.jpg'
def test_url_with_path(self):
"""Parse URL with complex path."""
url = "http://example.com/path/to/image.jpg"
url = 'http://example.com/path/to/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "http://example.com/path/to/image.jpg"
assert result_url == 'http://example.com/path/to/image.jpg'
def test_url_with_fragment(self):
"""Parse URL with fragment (fragment is not part of query)."""
url = "http://example.com/image.jpg#fragment"
url = 'http://example.com/image.jpg#fragment'
result_url, query = get_qq_image_downloadable_url(url)
# Fragment is not included in query string parsing
assert "http://example.com/image.jpg" in result_url
assert 'http://example.com/image.jpg' in result_url
def test_https_url(self):
"""Parse HTTPS URL and preserve its scheme."""
url = "https://example.com/image.jpg"
url = 'https://example.com/image.jpg'
result_url, query = get_qq_image_downloadable_url(url)
assert result_url == "https://example.com/image.jpg"
assert result_url == 'https://example.com/image.jpg'
assert query == {}
def test_preserves_qq_https_scheme_and_query(self):
"""QQ image URLs keep HTTPS and query parameters."""
result_url, query = get_qq_image_downloadable_url(
'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1'
)
result_url, query = get_qq_image_downloadable_url('https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1')
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
assert query == {'term': ['2'], 'is_origin': ['1']}
@@ -88,50 +86,50 @@ class TestExtractB64AndFormat:
async def test_jpeg_data_uri(self):
"""Extract base64 and format from JPEG data URI."""
# Create a simple base64 string
original_data = b"test image data"
original_data = b'test image data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/jpeg;base64,{b64_data}"
data_uri = f'data:image/jpeg;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "jpeg"
assert result_format == 'jpeg'
@pytest.mark.asyncio
async def test_png_data_uri(self):
"""Extract base64 and format from PNG data URI."""
original_data = b"test png data"
original_data = b'test png data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/png;base64,{b64_data}"
data_uri = f'data:image/png;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "png"
assert result_format == 'png'
@pytest.mark.asyncio
async def test_gif_data_uri(self):
"""Extract base64 and format from GIF data URI."""
original_data = b"test gif data"
original_data = b'test gif data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/gif;base64,{b64_data}"
data_uri = f'data:image/gif;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "gif"
assert result_format == 'gif'
@pytest.mark.asyncio
async def test_webp_data_uri(self):
"""Extract base64 and format from WebP data URI."""
original_data = b"test webp data"
original_data = b'test webp data'
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/webp;base64,{b64_data}"
data_uri = f'data:image/webp;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == b64_data
assert result_format == "webp"
assert result_format == 'webp'
@pytest.mark.asyncio
async def test_complex_base64(self):
@@ -139,7 +137,7 @@ class TestExtractB64AndFormat:
# Base64 can include + and / characters
original_data = bytes(range(256)) # All byte values
b64_data = base64.b64encode(original_data).decode()
data_uri = f"data:image/png;base64,{b64_data}"
data_uri = f'data:image/png;base64,{b64_data}'
result_b64, result_format = await extract_b64_and_format(data_uri)
@@ -150,9 +148,9 @@ class TestExtractB64AndFormat:
@pytest.mark.asyncio
async def test_empty_base64(self):
"""Handle empty base64 string."""
data_uri = "data:image/png;base64,"
data_uri = 'data:image/png;base64,'
result_b64, result_format = await extract_b64_and_format(data_uri)
assert result_b64 == ""
assert result_format == "png"
assert result_b64 == ''
assert result_format == 'png'
+46 -46
View File
@@ -23,52 +23,52 @@ class TestImportDir:
def test_calls_importlib_for_each_python_file(self, tmp_path):
"""Should call importlib.import_module for each .py file."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "__init__.py").write_text("")
(module_dir / "module_a.py").write_text("VALUE_A = 'a'\n")
(module_dir / "module_b.py").write_text("VALUE_B = 'b'\n")
(module_dir / "readme.txt").write_text("not a module")
(module_dir / '__init__.py').write_text('')
(module_dir / 'module_a.py').write_text("VALUE_A = 'a'\n")
(module_dir / 'module_b.py').write_text("VALUE_B = 'b'\n")
(module_dir / 'readme.txt').write_text('not a module')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# Should call import_module for each .py file (excluding __init__.py)
assert mock_import.call_count == 2
def test_skips_init_py(self, tmp_path):
"""Should skip __init__.py when importing."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "__init__.py").write_text("")
(module_dir / "regular.py").write_text("VALUE = 1\n")
(module_dir / '__init__.py').write_text('')
(module_dir / 'regular.py').write_text('VALUE = 1\n')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# __init__.py should be skipped
mock_import.assert_called_once()
# The call should not include __init__
call_args = mock_import.call_args[0][0]
assert "__init__" not in call_args
assert '__init__' not in call_args
def test_ignores_non_py_files(self, tmp_path):
"""Should ignore non-.py files."""
module_dir = tmp_path / "test_modules"
module_dir = tmp_path / 'test_modules'
module_dir.mkdir()
(module_dir / "module.py").write_text("VALUE = 1\n")
(module_dir / "readme.txt").write_text("text")
(module_dir / "data.json").write_text("{}")
(module_dir / 'module.py').write_text('VALUE = 1\n')
(module_dir / 'readme.txt').write_text('text')
(module_dir / 'data.json').write_text('{}')
from langbot.pkg.utils import importutil
with patch.object(importlib, "import_module") as mock_import:
importutil.import_dir(str(module_dir), path_prefix="test_prefix.")
with patch.object(importlib, 'import_module') as mock_import:
importutil.import_dir(str(module_dir), path_prefix='test_prefix.')
# Only .py files should be imported
assert mock_import.call_count == 1
@@ -79,14 +79,14 @@ class TestImportModulesInPkg:
def test_imports_modules_from_package(self, tmp_path):
"""Should import all modules from a package object."""
mock_pkg = MagicMock()
mock_pkg.__file__ = str(tmp_path / "__init__.py")
mock_pkg.__file__ = str(tmp_path / '__init__.py')
(tmp_path / "__init__.py").write_text("")
(tmp_path / "mod1.py").write_text("MOD1 = 1\n")
(tmp_path / '__init__.py').write_text('')
(tmp_path / 'mod1.py').write_text('MOD1 = 1\n')
from langbot.pkg.utils import importutil
with patch.object(importutil, "import_dir") as mock_import_dir:
with patch.object(importutil, 'import_dir') as mock_import_dir:
importutil.import_modules_in_pkg(mock_pkg)
mock_import_dir.assert_called_once()
call_path = mock_import_dir.call_args[0][0]
@@ -101,11 +101,11 @@ class TestImportModulesInPkgs:
from langbot.pkg.utils import importutil
mock_pkg1 = MagicMock()
mock_pkg1.__file__ = "/path/to/pkg1/__init__.py"
mock_pkg1.__file__ = '/path/to/pkg1/__init__.py'
mock_pkg2 = MagicMock()
mock_pkg2.__file__ = "/path/to/pkg2/__init__.py"
mock_pkg2.__file__ = '/path/to/pkg2/__init__.py'
with patch.object(importutil, "import_modules_in_pkg") as mock_import:
with patch.object(importutil, 'import_modules_in_pkg') as mock_import:
importutil.import_modules_in_pkgs([mock_pkg1, mock_pkg2])
assert mock_import.call_count == 2
@@ -116,18 +116,18 @@ class TestImportDotStyleDir:
def test_converts_dot_notation_to_path(self, tmp_path):
"""Should convert dot notation to path and import."""
# Create structure matching the dot notation
(tmp_path / "my").mkdir()
(tmp_path / "my" / "pkg").mkdir()
(tmp_path / "my" / "pkg" / "test").mkdir()
(tmp_path / 'my').mkdir()
(tmp_path / 'my' / 'pkg').mkdir()
(tmp_path / 'my' / 'pkg' / 'test').mkdir()
from langbot.pkg.utils import importutil
with patch.object(importutil, "import_dir") as mock_import_dir:
importutil.import_dot_style_dir("my.pkg.test")
with patch.object(importutil, 'import_dir') as mock_import_dir:
importutil.import_dot_style_dir('my.pkg.test')
# The path should be converted using os.path.join
call_path = mock_import_dir.call_args[0][0]
# Should contain the path components joined
assert "my" in call_path
assert 'my' in call_path
class TestReadResourceFile:
@@ -137,16 +137,16 @@ class TestReadResourceFile:
"""Should read content from a resource file."""
from langbot.pkg.utils import importutil
content = importutil.read_resource_file("templates/config.yaml")
assert "admins:" in content
assert "edition: community" in content
content = importutil.read_resource_file('templates/config.yaml')
assert 'admins:' in content
assert 'edition: community' in content
def test_raises_for_nonexistent_file(self):
"""Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file("nonexistent/path/file.txt")
importutil.read_resource_file('nonexistent/path/file.txt')
class TestReadResourceFileBytes:
@@ -156,16 +156,16 @@ class TestReadResourceFileBytes:
"""Should read content as bytes from a resource file."""
from langbot.pkg.utils import importutil
content = importutil.read_resource_file_bytes("templates/config.yaml")
assert b"admins:" in content
assert b"edition: community" in content
content = importutil.read_resource_file_bytes('templates/config.yaml')
assert b'admins:' in content
assert b'edition: community' in content
def test_raises_for_nonexistent_file_bytes(self):
"""Should raise exception for non-existent resource file."""
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.read_resource_file_bytes("nonexistent/path/file.txt")
importutil.read_resource_file_bytes('nonexistent/path/file.txt')
class TestListResourceFiles:
@@ -175,9 +175,9 @@ class TestListResourceFiles:
"""Should list files in a resource directory."""
from langbot.pkg.utils import importutil
files = importutil.list_resource_files("templates")
assert "config.yaml" in files
assert "default-pipeline-config.json" in files
files = importutil.list_resource_files('templates')
assert 'config.yaml' in files
assert 'default-pipeline-config.json' in files
assert all(isinstance(file, str) for file in files)
def test_raises_for_nonexistent_directory(self):
@@ -185,8 +185,8 @@ class TestListResourceFiles:
from langbot.pkg.utils import importutil
with pytest.raises((FileNotFoundError, Exception)):
importutil.list_resource_files("nonexistent_directory_xyz")
importutil.list_resource_files('nonexistent_directory_xyz')
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])
+2 -1
View File
@@ -5,6 +5,7 @@ Tests cover:
- Docker environment detection
- WebSocket plugin runtime mode
"""
from __future__ import annotations
import os
@@ -86,4 +87,4 @@ class TestGetPlatform:
assert platform_module.use_websocket_to_connect_plugin_runtime() is True
# Restore
platform_module.standalone_runtime = original
platform_module.standalone_runtime = original
+17 -11
View File
@@ -60,10 +60,12 @@ class TestProxyManager:
async def test_initialize_config_overrides_env(self):
"""Config proxy overrides environment variables."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://config-proxy:8080',
'https': 'https://config-proxy:8443',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://config-proxy:8080',
'https': 'https://config-proxy:8443',
}
)
with patch.dict(os.environ, {'HTTP_PROXY': 'http://env-proxy:8080'}):
pm = ProxyManager(mock_app)
@@ -74,10 +76,12 @@ class TestProxyManager:
async def test_initialize_sets_env_variables(self):
"""initialize sets proxy to environment variables."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://test-proxy:8080',
'https': 'https://test-proxy:8443',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://test-proxy:8080',
'https': 'https://test-proxy:8443',
}
)
pm = ProxyManager(mock_app)
await pm.initialize()
@@ -143,9 +147,11 @@ class TestProxyManager:
async def test_initialize_http_only_config(self):
"""initialize handles http-only config."""
mock_app = self._create_mock_app(proxy_config={
'http': 'http://http-only:8080',
})
mock_app = self._create_mock_app(
proxy_config={
'http': 'http://http-only:8080',
}
)
# Clear any existing proxy env vars
env_backup = {}
+69 -88
View File
@@ -29,63 +29,63 @@ class TestGetRunnerCategory:
def test_empty_url_returns_unknown(self):
"""Empty or None URL should return UNKNOWN."""
assert get_runner_category("test", "") == RunnerCategory.UNKNOWN
assert get_runner_category("test", None) == RunnerCategory.UNKNOWN
assert get_runner_category('test', '') == RunnerCategory.UNKNOWN
assert get_runner_category('test', None) == RunnerCategory.UNKNOWN
def test_localhost_returns_local(self):
"""localhost URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://localhost:3000") == RunnerCategory.LOCAL
assert get_runner_category("test", "https://localhost") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://localhost:3000') == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://localhost') == RunnerCategory.LOCAL
def test_127_0_0_1_returns_local(self):
"""127.0.0.1 URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://127.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "https://127.0.0.1") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://127.0.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'https://127.0.0.1') == RunnerCategory.LOCAL
def test_0_0_0_0_returns_local(self):
"""0.0.0.0 URL should be categorized as LOCAL."""
assert get_runner_category("test", "http://0.0.0.0:8080") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://0.0.0.0:8080') == RunnerCategory.LOCAL
def test_private_ip_192_168_returns_local(self):
"""192.168.x.x private IP should be categorized as LOCAL."""
assert get_runner_category("test", "http://192.168.1.1:3000") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://192.168.0.100") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://192.168.1.1:3000') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://192.168.0.100') == RunnerCategory.LOCAL
def test_private_ip_10_returns_local(self):
"""10.x.x.x private IP should be categorized as LOCAL."""
assert get_runner_category("test", "http://10.0.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://10.255.255.255") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://10.0.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://10.255.255.255') == RunnerCategory.LOCAL
def test_private_ip_172_16_31_returns_local(self):
"""172.16.x.x - 172.31.x.x private IP range should be categorized as LOCAL."""
assert get_runner_category("test", "http://172.16.0.1:8080") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://172.20.0.1") == RunnerCategory.LOCAL
assert get_runner_category("test", "http://172.31.255.255") == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.16.0.1:8080') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.20.0.1') == RunnerCategory.LOCAL
assert get_runner_category('test', 'http://172.31.255.255') == RunnerCategory.LOCAL
def test_n8n_cloud_returns_cloud(self):
"""n8n.cloud domain should be categorized as CLOUD."""
assert get_runner_category("test", "https://myinstance.n8n.cloud") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://test.n8n.io") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://myinstance.n8n.cloud') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.n8n.io') == RunnerCategory.CLOUD
def test_dify_cloud_returns_cloud(self):
"""Dify cloud domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://api.dify.ai/v1") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://cloud.dify.ai") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.dify.ai/v1') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://cloud.dify.ai') == RunnerCategory.CLOUD
def test_coze_cloud_returns_cloud(self):
"""Coze domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://api.coze.com") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://api.coze.cn") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.coze.com') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://api.coze.cn') == RunnerCategory.CLOUD
def test_langflow_cloud_returns_cloud(self):
"""Langflow domains should be categorized as CLOUD."""
assert get_runner_category("test", "https://cloud.langflow.ai") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://test.langflow.org") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://cloud.langflow.ai') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://test.langflow.org') == RunnerCategory.CLOUD
def test_other_url_returns_cloud(self):
"""Other URLs should default to CLOUD category."""
assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD
assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://example.com') == RunnerCategory.CLOUD
assert get_runner_category('test', 'https://myserver.example.org') == RunnerCategory.CLOUD
@pytest.mark.parametrize(
'runner_url',
@@ -101,7 +101,7 @@ class TestGetRunnerCategory:
)
def test_invalid_urls_return_unknown(self, runner_url):
"""Invalid or incomplete URLs should return UNKNOWN."""
assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
assert get_runner_category('test', runner_url) == RunnerCategory.UNKNOWN
def test_urlparse_exception_returns_unknown(self):
"""Exception during URL parsing should return UNKNOWN."""
@@ -109,15 +109,15 @@ class TestGetRunnerCategory:
from langbot.pkg.utils import runner
def mock_urlparse(url):
raise Exception("URL parsing failed")
raise Exception('URL parsing failed')
with patch("langbot.pkg.utils.runner.urlparse", side_effect=mock_urlparse):
result = runner.get_runner_category("test", "http://example.com")
with patch('langbot.pkg.utils.runner.urlparse', side_effect=mock_urlparse):
result = runner.get_runner_category('test', 'http://example.com')
assert result == RunnerCategory.UNKNOWN
def test_url_without_scheme_returns_unknown(self):
"""URL without scheme should return UNKNOWN."""
assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN
assert get_runner_category('test', 'example.com') == RunnerCategory.UNKNOWN
@pytest.mark.parametrize(
'runner_url',
@@ -146,20 +146,21 @@ class TestGetRunnerCategory:
"""Domain names that only look like private IP prefixes should not be LOCAL."""
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
class TestIsCloudRunner:
"""Test is_cloud_runner helper function."""
def test_cloud_runner_returns_true(self):
"""Cloud URL should return True."""
assert is_cloud_runner("test", "https://api.dify.ai") is True
assert is_cloud_runner('test', 'https://api.dify.ai') is True
def test_local_runner_returns_false(self):
"""Local URL should return False."""
assert is_cloud_runner("test", "http://localhost:3000") is False
assert is_cloud_runner('test', 'http://localhost:3000') is False
def test_unknown_returns_false(self):
"""Unknown category should return False."""
assert is_cloud_runner("test", None) is False
assert is_cloud_runner('test', None) is False
class TestIsLocalRunner:
@@ -167,15 +168,15 @@ class TestIsLocalRunner:
def test_local_runner_returns_true(self):
"""Local URL should return True."""
assert is_local_runner("test", "http://localhost:3000") is True
assert is_local_runner('test', 'http://localhost:3000') is True
def test_cloud_runner_returns_false(self):
"""Cloud URL should return False."""
assert is_local_runner("test", "https://api.dify.ai") is False
assert is_local_runner('test', 'https://api.dify.ai') is False
def test_unknown_returns_false(self):
"""Unknown category should return False."""
assert is_local_runner("test", None) is False
assert is_local_runner('test', None) is False
class TestGetRunnerInfo:
@@ -183,17 +184,17 @@ class TestGetRunnerInfo:
def test_returns_dict_with_expected_keys(self):
"""Should return dict with name, url, and category keys."""
info = get_runner_info("my-runner", "http://localhost:3000")
assert "name" in info
assert "url" in info
assert "category" in info
info = get_runner_info('my-runner', 'http://localhost:3000')
assert 'name' in info
assert 'url' in info
assert 'category' in info
def test_includes_correct_values(self):
"""Should include correct values in dict."""
info = get_runner_info("my-runner", "http://localhost:3000")
assert info["name"] == "my-runner"
assert info["url"] == "http://localhost:3000"
assert info["category"] == RunnerCategory.LOCAL
info = get_runner_info('my-runner', 'http://localhost:3000')
assert info['name'] == 'my-runner'
assert info['url'] == 'http://localhost:3000'
assert info['category'] == RunnerCategory.LOCAL
class TestExtractRunnerUrl:
@@ -203,74 +204,58 @@ class TestExtractRunnerUrl:
"""Should extract base-url from dify-service-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
url = extract_runner_url("dify-service-api", runner, pipeline_config)
assert url == "https://api.dify.ai"
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
url = extract_runner_url('dify-service-api', runner, pipeline_config)
assert url == 'https://api.dify.ai'
def test_n8n_service_api_extracts_url(self):
"""Should extract webhook-url from n8n-service-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"n8n-service-api": {"webhook-url": "https://my.n8n.cloud/webhook"}
}
}
url = extract_runner_url("n8n-service-api", runner, pipeline_config)
assert url == "https://my.n8n.cloud/webhook"
pipeline_config = {'ai': {'n8n-service-api': {'webhook-url': 'https://my.n8n.cloud/webhook'}}}
url = extract_runner_url('n8n-service-api', runner, pipeline_config)
assert url == 'https://my.n8n.cloud/webhook'
def test_coze_api_extracts_url(self):
"""Should extract api-base from coze-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"coze-api": {"api-base": "https://api.coze.com"}
}
}
url = extract_runner_url("coze-api", runner, pipeline_config)
assert url == "https://api.coze.com"
pipeline_config = {'ai': {'coze-api': {'api-base': 'https://api.coze.com'}}}
url = extract_runner_url('coze-api', runner, pipeline_config)
assert url == 'https://api.coze.com'
def test_langflow_api_extracts_url(self):
"""Should extract base-url from langflow-api config."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"langflow-api": {"base-url": "https://cloud.langflow.ai"}
}
}
url = extract_runner_url("langflow-api", runner, pipeline_config)
assert url == "https://cloud.langflow.ai"
pipeline_config = {'ai': {'langflow-api': {'base-url': 'https://cloud.langflow.ai'}}}
url = extract_runner_url('langflow-api', runner, pipeline_config)
assert url == 'https://cloud.langflow.ai'
def test_unknown_runner_returns_none(self):
"""Unknown runner name should return None."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {}
url = extract_runner_url("unknown-runner", runner, pipeline_config)
url = extract_runner_url('unknown-runner', runner, pipeline_config)
assert url is None
def test_none_runner_returns_none(self):
"""None runner should return None."""
url = extract_runner_url("test", None, {})
url = extract_runner_url('test', None, {})
assert url is None
def test_runner_without_pipeline_config_returns_none(self):
"""Runner without pipeline_config attribute should return None."""
runner = Mock(spec=[]) # Empty spec means no attributes
url = extract_runner_url("test", runner, {})
url = extract_runner_url('test', runner, {})
assert url is None
def test_none_pipeline_config_returns_none(self):
"""None pipeline_config should return None."""
runner = Mock()
runner.pipeline_config = {}
url = extract_runner_url("dify-service-api", runner, None)
url = extract_runner_url('dify-service-api', runner, None)
assert url is None
def test_missing_ai_config_returns_none(self):
@@ -278,7 +263,7 @@ class TestExtractRunnerUrl:
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {}
url = extract_runner_url("dify-service-api", runner, pipeline_config)
url = extract_runner_url('dify-service-api', runner, pipeline_config)
assert url is None
@@ -289,19 +274,15 @@ class TestGetRunnerCategoryFromRunner:
"""Should extract URL and return correct category."""
runner = Mock()
runner.pipeline_config = {}
pipeline_config = {
"ai": {
"dify-service-api": {"base-url": "https://api.dify.ai"}
}
}
category = get_runner_category_from_runner("dify-service-api", runner, pipeline_config)
pipeline_config = {'ai': {'dify-service-api': {'base-url': 'https://api.dify.ai'}}}
category = get_runner_category_from_runner('dify-service-api', runner, pipeline_config)
assert category == RunnerCategory.CLOUD
def test_returns_unknown_for_missing_url(self):
"""Should return UNKNOWN when URL cannot be extracted."""
runner = Mock()
runner.pipeline_config = {}
category = get_runner_category_from_runner("unknown", runner, {})
category = get_runner_category_from_runner('unknown', runner, {})
assert category == RunnerCategory.UNKNOWN
@@ -310,9 +291,9 @@ class TestConstants:
def test_runner_category_constants(self):
"""RunnerCategory should have LOCAL, CLOUD, UNKNOWN."""
assert RunnerCategory.LOCAL == "local"
assert RunnerCategory.CLOUD == "cloud"
assert RunnerCategory.UNKNOWN == "unknown"
assert RunnerCategory.LOCAL == 'local'
assert RunnerCategory.CLOUD == 'cloud'
assert RunnerCategory.UNKNOWN == 'unknown'
def test_cloud_domains_not_empty(self):
"""CLOUD_DOMAINS should not be empty."""
@@ -323,5 +304,5 @@ class TestConstants:
assert len(LOCAL_PATTERNS) > 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])
if __name__ == '__main__':
pytest.main([__file__, '-v'])

Some files were not shown because too many files have changed in this diff Show More