Compare commits

..

12 Commits

Author SHA1 Message Date
huanghuoguoguo
8789c42eeb feat(monitoring): add host RAG trace observability 2026-06-17 00:13:57 +08:00
huanghuoguoguo
b3c6de2072 [codex] cover frontend CRUD smoke flows (#2253)
* test: cover frontend CRUD smoke flows

* test: add bot CRUD smoke coverage

* test: add bot/pipeline advanced flows and cross-resource tests

- Bot enable/disable toggle with state persistence
- Bot detail tab switching (Configuration, Logs, Sessions)
- Bot form dirty state and save button behavior
- Bot name validation error display
- Pipeline tab switching (Configuration, Dashboard)
- Pipeline form dirty state
- Pipeline name validation error display
- Cross-resource flow: create pipeline then bind to bot
- Empty states for bots, pipelines, knowledge bases, MCP servers
2026-06-16 21:34:17 +08:00
RockChinQ
4e45886647 style(web): show Models above API Integration in main sidebar footer 2026-06-16 06:04:59 -04:00
RockChinQ
f592656680 refactor(web): unify settings panel layouts with shared toolbar/body
- Add PanelToolbar/PanelBody primitives so all four settings tabs share
  the same top-toolbar + scrollable-body rhythm under the unified header.
- API panel: drop the heavy gray shadowed TabsList; move the create
  action into the toolbar next to the tabs, lighten per-tab hints.
- Storage panel: reuse PanelToolbar for the generated-at/refresh bar.
- Account panel: wrap content in PanelBody for consistent padding.
- Models panel: keep the pinned LangBot Models (Space) card at the very
  top, above the add-custom-provider row (intentional pin), using
  PanelBody instead of a top toolbar.
2026-06-16 06:02:20 -04:00
RockChinQ
e9db858dcc feat(web): unified header for settings dialog, shorter sidebar labels
- Add a shared section header (icon + title + description) with right
  padding so the dialog close X no longer overlaps panel content, and
  every tab now shares the same top layout for a consistent look.
- Shorten inner sidebar nav labels (Models/API/Storage/Account) via new
  settingsDialog.nav.* i18n keys across all 8 locales.
- Add common.apiIntegrationDescription and account.settingsDescription
  for the new header.
2026-06-16 05:50:44 -04:00
RockChinQ
2d6faf9d5e refactor(web): drop legacy ModelsDialog, use unified SettingsDialog everywhere
The model-selector in dynamic forms (pipeline / knowledge base settings)
still opened the old standalone ModelsDialog. Point it at the unified
SettingsDialog (section pinned to models) and delete the now-unused
ModelsDialog wrapper so only the new dialog remains.
2026-06-16 05:41:58 -04:00
RockChinQ
d4699547e9 i18n(web): localize Bots/Pipelines sidebar titles for es/th/vi
es-ES pipelines, th-TH bots+pipelines and vi-VN pipelines were left in
English in the sidebar. Translate them: es Flujos, th บอท/ไปป์ไลน์,
vi Quy trình.
2026-06-16 05:27:10 -04:00
RockChinQ
716d7aca94 fix(web): fixed-height settings dialog, narrower sidebar
Pin the dialog to a fixed 80vh (cap 800px) so switching sections no
longer resizes it; panels scroll their own content internally. Override
the SidebarProvider wrapper's default h-svh with h-full so both columns
fill the dialog height. Narrow the inner settings sidebar to w-44.
2026-06-16 05:22:42 -04:00
RockChinQ
b3c00fe6da fix(web): use fixed height for settings dialog instead of 80vh
Avoid the dialog stretching to fill tall viewports (large empty space).
Pin to 620px with max-h-[85vh] fallback and narrow width to 52rem.
2026-06-16 05:18:14 -04:00
RockChinQ
f4a6edf7ec refactor(web): unify settings dialogs into single dialog with sidebar
Merge API integration, model settings, account settings and storage
analysis into one SettingsDialog with a shadcn inner sidebar for
section switching. Preserve existing ?action= query-param deep links
(showModelSettings / showAccountSettings / showApiIntegrationSettings /
showStorageAnalysis) by mapping each to a section. Extract reusable
panels and keep ModelsDialog as a thin wrapper for the dynamic-form
model picker.
2026-06-16 05:06:06 -04:00
huanghuoguoguo
f390980d0a test: format test suite (#2252) 2026-06-16 11:22:29 +08:00
huanghuoguoguo
1ae5aacc00 test: add frontend smoke and backend e2e CI (#2251) 2026-06-16 11:09:55 +08:00
158 changed files with 6125 additions and 3577 deletions

46
.github/workflows/frontend-tests.yml vendored Normal file
View File

@@ -0,0 +1,46 @@
name: Frontend Tests
on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
push:
branches:
- master
- develop
paths:
- 'web/**'
- '.github/workflows/frontend-tests.yml'
jobs:
playwright-smoke:
name: Playwright Smoke
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '25'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 8.9.2
- name: Install dependencies
working-directory: web
run: pnpm install --frozen-lockfile
- name: Install Playwright browsers
working-directory: web
run: pnpm exec playwright install --with-deps chromium
- name: Run Playwright smoke tests
working-directory: web
run: pnpm test:e2e

View File

@@ -29,7 +29,7 @@ jobs:
run: uv sync --dev
- name: Run ruff check
run: uv run ruff check src
run: uv run ruff check src/langbot/ tests/ --output-format=concise
- name: Run ruff format
run: uv run ruff format src --check

View File

@@ -84,6 +84,67 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
e2e:
name: E2E Startup Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run E2E startup tests
run: uv run pytest tests/e2e -q --tb=short
- name: E2E Test Summary
if: always()
run: |
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
box-integration:
name: Box Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Check Docker runtime
run: docker info
- name: Run Box integration tests
run: uv run pytest tests/integration_tests -q --tb=short
- name: Box Integration Test Summary
if: always()
run: |
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage:
name: Coverage Gate
runs-on: ubuntu-latest
@@ -129,4 +190,4 @@ jobs:
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

View File

@@ -313,18 +313,30 @@ class MonitoringRouterGroup(group.RouterGroup):
offset=0,
)
# Get traces
traces, traces_total = await self.ap.monitoring_service.get_traces(
bot_ids=bot_ids if bot_ids else None,
pipeline_ids=pipeline_ids if pipeline_ids else None,
start_time=start_time,
end_time=end_time,
limit=limit,
offset=0,
)
return self.success(
data={
'overview': overview,
'messages': messages,
'llmCalls': llm_calls,
'embeddingCalls': embedding_calls,
'traces': traces,
'sessions': sessions,
'errors': errors,
'totalCount': {
'messages': messages_total,
'llmCalls': llm_calls_total,
'embeddingCalls': embedding_calls_total,
'traces': traces_total,
'sessions': sessions_total,
'errors': errors_total,
},
@@ -350,6 +362,49 @@ class MonitoringRouterGroup(group.RouterGroup):
return self.success(data=details)
@self.route('/traces', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def get_traces() -> str:
"""Get end-to-end trace records."""
bot_ids = quart.request.args.getlist('botId')
pipeline_ids = quart.request.args.getlist('pipelineId')
session_ids = quart.request.args.getlist('sessionId')
statuses = quart.request.args.getlist('status')
start_time_str = quart.request.args.get('startTime')
end_time_str = quart.request.args.get('endTime')
limit = int(quart.request.args.get('limit', 100))
offset = int(quart.request.args.get('offset', 0))
start_time = parse_iso_datetime(start_time_str)
end_time = parse_iso_datetime(end_time_str)
traces, total = await self.ap.monitoring_service.get_traces(
bot_ids=bot_ids if bot_ids else None,
pipeline_ids=pipeline_ids if pipeline_ids else None,
session_ids=session_ids if session_ids else None,
statuses=statuses if statuses else None,
start_time=start_time,
end_time=end_time,
limit=limit,
offset=offset,
)
return self.success(
data={
'traces': traces,
'total': total,
'limit': limit,
'offset': offset,
}
)
@self.route('/traces/<trace_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def get_trace_details(trace_id: str) -> str:
"""Get one trace with all spans."""
details = await self.ap.monitoring_service.get_trace_details(trace_id)
if not details.get('found'):
return self.http_status(404, -1, f'Trace {trace_id} not found')
return self.success(data=details)
@self.route('/export', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def export_data() -> tuple[str, int]:
"""Export monitoring data as CSV"""

View File

@@ -350,8 +350,24 @@ class PluginsRouterGroup(group.RouterGroup):
if not endpoint.startswith('/') or '..' in endpoint:
return self.http_status(400, -1, 'invalid endpoint')
caller = {
'plugin_author': author,
'plugin_name': plugin_name,
'page_id': page_id,
'origin': _get_request_origin(),
}
headers = {
key: value
for key, value in {
'user-agent': quart.request.headers.get('User-Agent'),
'x-request-id': quart.request.headers.get('X-Request-ID'),
'x-forwarded-for': quart.request.headers.get('X-Forwarded-For'),
}.items()
if value
}
result = await self.ap.plugin_connector.handle_page_api(
author, plugin_name, page_id, endpoint, method.upper(), body
author, plugin_name, page_id, endpoint, method.upper(), body, caller, headers
)
if result.get('error'):
return self.http_status(400, -1, result['error'])

View File

@@ -3,11 +3,53 @@ from __future__ import annotations
import uuid
import datetime
import sqlalchemy
import json
from ....core import app
from ....entity.persistence import monitoring as persistence_monitoring
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
def _json_dumps(value: dict | list | None) -> str | None:
if value is None:
return None
try:
return json.dumps(value, ensure_ascii=False, default=str)
except Exception:
return json.dumps({'serialization_error': str(value)}, ensure_ascii=False)
def _json_loads(value: str | None) -> dict | list | None:
if not value:
return None
try:
return json.loads(value)
except Exception:
return None
def new_trace_id() -> str:
return f'trace-{uuid.uuid4().hex[:16]}'
def new_span_id() -> str:
return f'span-{uuid.uuid4().hex[:16]}'
def normalize_trace_status(status: str | None) -> str:
"""Normalize operation status to the monitoring UI vocabulary."""
if status in ('completed', 'ok'):
return 'success'
if status in ('failed', 'failure', 'exception'):
return 'error'
if status in ('running', 'success', 'error'):
return status
return 'success'
class MonitoringService:
"""Monitoring service"""
@@ -74,6 +116,18 @@ class MonitoringService:
persistence_monitoring.MonitoringFeedback.timestamp,
persistence_monitoring.MonitoringFeedback.id,
),
(
'monitoring_traces',
persistence_monitoring.MonitoringTrace,
persistence_monitoring.MonitoringTrace.started_at,
persistence_monitoring.MonitoringTrace.trace_id,
),
(
'monitoring_spans',
persistence_monitoring.MonitoringSpan,
persistence_monitoring.MonitoringSpan.started_at,
persistence_monitoring.MonitoringSpan.span_id,
),
]
deleted_counts: dict[str, int] = {}
@@ -133,6 +187,116 @@ class MonitoringService:
# ========== Recording Methods ==========
async def start_trace(
self,
trace_id: str | None = None,
name: str = 'LangBot query',
bot_id: str | None = None,
bot_name: str | None = None,
pipeline_id: str | None = None,
pipeline_name: str | None = None,
session_id: str | None = None,
message_id: str | None = None,
query_id: str | int | None = None,
attributes: dict | None = None,
) -> str:
"""Create or update a trace header row."""
trace_id = trace_id or new_trace_id()
trace_data = {
'trace_id': trace_id,
'started_at': _utc_now(),
'ended_at': None,
'duration': None,
'status': 'running',
'name': name,
'bot_id': bot_id,
'bot_name': bot_name,
'pipeline_id': pipeline_id,
'pipeline_name': pipeline_name,
'session_id': session_id,
'message_id': message_id,
'query_id': str(query_id) if query_id is not None else None,
'attributes': _json_dumps(attributes),
}
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_monitoring.MonitoringTrace).values(trace_data)
)
return trace_id
async def finish_trace(
self,
trace_id: str,
status: str = 'success',
duration: int | None = None,
message_id: str | None = None,
attributes: dict | None = None,
) -> None:
"""Mark a trace complete."""
update_values: dict = {
'ended_at': _utc_now(),
'status': normalize_trace_status(status),
}
if duration is not None:
update_values['duration'] = duration
if message_id is not None:
update_values['message_id'] = message_id
if attributes is not None:
update_values['attributes'] = _json_dumps(attributes)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_monitoring.MonitoringTrace)
.where(persistence_monitoring.MonitoringTrace.trace_id == trace_id)
.values(update_values)
)
async def record_span(
self,
trace_id: str,
name: str,
kind: str,
status: str = 'success',
span_id: str | None = None,
parent_span_id: str | None = None,
started_at: datetime.datetime | None = None,
ended_at: datetime.datetime | None = None,
duration: int | None = None,
message_id: str | None = None,
session_id: str | None = None,
bot_id: str | None = None,
pipeline_id: str | None = None,
attributes: dict | None = None,
error_message: str | None = None,
) -> str:
"""Record a single completed span."""
started_at = started_at or _utc_now()
if duration is None and ended_at is not None:
duration = int((ended_at - started_at).total_seconds() * 1000)
elif duration is not None:
duration = int(round(float(duration)))
span_data = {
'span_id': span_id or new_span_id(),
'trace_id': trace_id,
'parent_span_id': parent_span_id,
'name': name,
'kind': kind,
'status': normalize_trace_status(status),
'started_at': started_at,
'ended_at': ended_at or _utc_now(),
'duration': duration,
'message_id': message_id,
'session_id': session_id,
'bot_id': bot_id,
'pipeline_id': pipeline_id,
'attributes': _json_dumps(attributes),
'error_message': error_message,
}
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_monitoring.MonitoringSpan).values(span_data)
)
return span_data['span_id']
async def record_message(
self,
bot_id: str,
@@ -1076,6 +1240,19 @@ class MonitoringService:
for row in error_rows
]
trace_query = (
sqlalchemy.select(persistence_monitoring.MonitoringTrace)
.where(persistence_monitoring.MonitoringTrace.message_id == message_id)
.order_by(persistence_monitoring.MonitoringTrace.started_at.desc())
.limit(1)
)
trace_result = await self.ap.persistence_mgr.execute_async(trace_query)
trace_row = trace_result.first()
trace = None
if trace_row:
trace_model = trace_row[0] if isinstance(trace_row, tuple) else trace_row
trace = self._serialize_trace(trace_model)
return {
'message_id': message_id,
'found': True,
@@ -1090,6 +1267,90 @@ class MonitoringService:
'average_duration_ms': int(total_duration / len(llm_rows)) if len(llm_rows) > 0 else 0,
},
'errors': errors,
'trace': trace,
}
def _serialize_trace(self, trace: persistence_monitoring.MonitoringTrace) -> dict:
data = self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringTrace, trace)
data['attributes'] = _json_loads(data.get('attributes')) or {}
return data
def _serialize_span(self, span: persistence_monitoring.MonitoringSpan) -> dict:
data = self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringSpan, span)
data['attributes'] = _json_loads(data.get('attributes')) or {}
return data
async def get_traces(
self,
bot_ids: list[str] | None = None,
pipeline_ids: list[str] | None = None,
session_ids: list[str] | None = None,
statuses: list[str] | None = None,
start_time: datetime.datetime | None = None,
end_time: datetime.datetime | None = None,
limit: int = 100,
offset: int = 0,
) -> tuple[list[dict], int]:
"""Get trace headers with filters."""
conditions = []
if bot_ids:
conditions.append(persistence_monitoring.MonitoringTrace.bot_id.in_(bot_ids))
if pipeline_ids:
conditions.append(persistence_monitoring.MonitoringTrace.pipeline_id.in_(pipeline_ids))
if session_ids:
conditions.append(persistence_monitoring.MonitoringTrace.session_id.in_(session_ids))
if statuses:
conditions.append(persistence_monitoring.MonitoringTrace.status.in_(statuses))
if start_time:
conditions.append(persistence_monitoring.MonitoringTrace.started_at >= start_time)
if end_time:
conditions.append(persistence_monitoring.MonitoringTrace.started_at <= end_time)
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringTrace.trace_id))
query = sqlalchemy.select(persistence_monitoring.MonitoringTrace)
if conditions:
clause = sqlalchemy.and_(*conditions)
count_query = count_query.where(clause)
query = query.where(clause)
total_result = await self.ap.persistence_mgr.execute_async(count_query)
total = total_result.scalar() or 0
query = query.order_by(persistence_monitoring.MonitoringTrace.started_at.desc()).limit(limit).offset(offset)
result = await self.ap.persistence_mgr.execute_async(query)
traces = [
self._serialize_trace(row[0] if isinstance(row, tuple) else row)
for row in result.all()
]
return traces, total
async def get_trace_details(self, trace_id: str) -> dict:
"""Get a single trace and all spans in chronological order."""
trace_query = sqlalchemy.select(persistence_monitoring.MonitoringTrace).where(
persistence_monitoring.MonitoringTrace.trace_id == trace_id
)
trace_result = await self.ap.persistence_mgr.execute_async(trace_query)
trace_row = trace_result.first()
if not trace_row:
return {'trace_id': trace_id, 'found': False}
trace = trace_row[0] if isinstance(trace_row, tuple) else trace_row
span_query = (
sqlalchemy.select(persistence_monitoring.MonitoringSpan)
.where(persistence_monitoring.MonitoringSpan.trace_id == trace_id)
.order_by(persistence_monitoring.MonitoringSpan.started_at.asc())
)
span_result = await self.ap.persistence_mgr.execute_async(span_query)
spans = [
self._serialize_span(row[0] if isinstance(row, tuple) else row)
for row in span_result.all()
]
return {
'trace_id': trace_id,
'found': True,
'trace': self._serialize_trace(trace),
'spans': spans,
}
# ========== Export Methods ==========

View File

@@ -146,19 +146,13 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
_LB_PIP_CACHE_DIR="{mount_path}/.cache/pip"
mkdir -p "$_LB_META_DIR" "$_LB_TMP_DIR" "$_LB_PIP_CACHE_DIR"
_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"
if [ -z "$_LB_SYSTEM_PYTHON" ]; then
echo "python3 or python is required to prepare the workspace Python environment" >&2
exit 127
fi
export TMPDIR="$_LB_TMP_DIR"
export TEMP="$_LB_TMP_DIR"
export TMP="$_LB_TMP_DIR"
export PIP_CACHE_DIR="$_LB_PIP_CACHE_DIR"
_lb_python_meta() {{
"$_LB_SYSTEM_PYTHON" - <<'PY'
python - <<'PY'
import hashlib
import json
import os
@@ -207,26 +201,15 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
_LB_LOCK_WAIT=0
while ! mkdir "$_LB_LOCK_DIR" 2>/dev/null; do
if [ "$_LB_LOCK_WAIT" -ge 120 ]; then
_LB_LOCK_OWNER="$(cat "$_LB_LOCK_DIR/pid" 2>/dev/null || true)"
if [ -n "$_LB_LOCK_OWNER" ] && kill -0 "$_LB_LOCK_OWNER" 2>/dev/null; then
echo "Timed out waiting for active Python environment lock: $_LB_LOCK_DIR" >&2
exit 1
fi
echo "Timed out waiting for Python environment lock, clearing stale lock: $_LB_LOCK_DIR" >&2
rm -rf "$_LB_LOCK_DIR" 2>/dev/null || true
if mkdir "$_LB_LOCK_DIR" 2>/dev/null; then
break
fi
echo "Timed out waiting for Python environment lock: $_LB_LOCK_DIR" >&2
exit 1
fi
sleep 1
_LB_LOCK_WAIT=$((_LB_LOCK_WAIT + 1))
done
printf '%s\\n' "$$" > "$_LB_LOCK_DIR/pid" 2>/dev/null || true
_lb_cleanup_lock() {{
rm -rf "$_LB_LOCK_DIR" >/dev/null 2>&1 || true
rmdir "$_LB_LOCK_DIR" >/dev/null 2>&1 || true
}}
trap _lb_cleanup_lock EXIT INT TERM
@@ -242,7 +225,7 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
if [ "$_LB_NEEDS_BOOTSTRAP" -eq 1 ]; then
rm -rf "$_LB_VENV_DIR"
"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"
python -m venv "$_LB_VENV_DIR"
. "$_LB_VENV_DIR/bin/activate"
python -m pip install --upgrade pip setuptools wheel
if [ -f "{mount_path}/requirements.txt" ]; then

View File

@@ -3,6 +3,49 @@ import sqlalchemy
from .base import Base
class MonitoringTrace(Base):
"""End-to-end monitoring trace records"""
__tablename__ = 'monitoring_traces'
trace_id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
ended_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) # milliseconds
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True) # running, success, error
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
query_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
attributes = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
class MonitoringSpan(Base):
"""Trace span records for pipeline, RAG, model, plugin and tool operations"""
__tablename__ = 'monitoring_spans'
span_id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
trace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
parent_span_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
kind = sqlalchemy.Column(sqlalchemy.String(80), nullable=False, index=True)
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
ended_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) # milliseconds
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
attributes = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
error_message = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
class MonitoringMessage(Base):
"""Monitoring message records"""

View File

@@ -0,0 +1,88 @@
"""add monitoring traces and spans
Revision ID: 0006_monitoring_traces
Revises: 0005_add_llm_context_length
Create Date: 2026-06-16
"""
import sqlalchemy as sa
from alembic import op
revision = '0006_monitoring_traces'
down_revision = '0005_add_llm_context_length'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
tables = set(inspector.get_table_names())
if 'monitoring_traces' not in tables:
op.create_table(
'monitoring_traces',
sa.Column('trace_id', sa.String(length=255), nullable=False),
sa.Column('started_at', sa.DateTime(), nullable=False),
sa.Column('ended_at', sa.DateTime(), nullable=True),
sa.Column('duration', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('bot_id', sa.String(length=255), nullable=True),
sa.Column('bot_name', sa.String(length=255), nullable=True),
sa.Column('pipeline_id', sa.String(length=255), nullable=True),
sa.Column('pipeline_name', sa.String(length=255), nullable=True),
sa.Column('session_id', sa.String(length=255), nullable=True),
sa.Column('message_id', sa.String(length=255), nullable=True),
sa.Column('query_id', sa.String(length=255), nullable=True),
sa.Column('attributes', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('trace_id'),
)
op.create_index('ix_monitoring_traces_started_at', 'monitoring_traces', ['started_at'])
op.create_index('ix_monitoring_traces_ended_at', 'monitoring_traces', ['ended_at'])
op.create_index('ix_monitoring_traces_status', 'monitoring_traces', ['status'])
op.create_index('ix_monitoring_traces_bot_id', 'monitoring_traces', ['bot_id'])
op.create_index('ix_monitoring_traces_pipeline_id', 'monitoring_traces', ['pipeline_id'])
op.create_index('ix_monitoring_traces_session_id', 'monitoring_traces', ['session_id'])
op.create_index('ix_monitoring_traces_message_id', 'monitoring_traces', ['message_id'])
op.create_index('ix_monitoring_traces_query_id', 'monitoring_traces', ['query_id'])
if 'monitoring_spans' not in tables:
op.create_table(
'monitoring_spans',
sa.Column('span_id', sa.String(length=255), nullable=False),
sa.Column('trace_id', sa.String(length=255), nullable=False),
sa.Column('parent_span_id', sa.String(length=255), nullable=True),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('kind', sa.String(length=80), nullable=False),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('started_at', sa.DateTime(), nullable=False),
sa.Column('ended_at', sa.DateTime(), nullable=True),
sa.Column('duration', sa.Integer(), nullable=True),
sa.Column('message_id', sa.String(length=255), nullable=True),
sa.Column('session_id', sa.String(length=255), nullable=True),
sa.Column('bot_id', sa.String(length=255), nullable=True),
sa.Column('pipeline_id', sa.String(length=255), nullable=True),
sa.Column('attributes', sa.Text(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('span_id'),
)
op.create_index('ix_monitoring_spans_trace_id', 'monitoring_spans', ['trace_id'])
op.create_index('ix_monitoring_spans_parent_span_id', 'monitoring_spans', ['parent_span_id'])
op.create_index('ix_monitoring_spans_kind', 'monitoring_spans', ['kind'])
op.create_index('ix_monitoring_spans_status', 'monitoring_spans', ['status'])
op.create_index('ix_monitoring_spans_started_at', 'monitoring_spans', ['started_at'])
op.create_index('ix_monitoring_spans_message_id', 'monitoring_spans', ['message_id'])
op.create_index('ix_monitoring_spans_session_id', 'monitoring_spans', ['session_id'])
op.create_index('ix_monitoring_spans_bot_id', 'monitoring_spans', ['bot_id'])
op.create_index('ix_monitoring_spans_pipeline_id', 'monitoring_spans', ['pipeline_id'])
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
tables = set(inspector.get_table_names())
if 'monitoring_spans' in tables:
op.drop_table('monitoring_spans')
if 'monitoring_traces' in tables:
op.drop_table('monitoring_traces')

View File

@@ -2,6 +2,9 @@ from __future__ import annotations
import typing
import traceback
import time
import uuid
import datetime
import sqlalchemy
@@ -79,6 +82,19 @@ class RuntimePipeline:
enable_all_plugins: bool
"""是否启用所有插件"""
@staticmethod
def _new_span_id() -> str:
return f'span-{uuid.uuid4().hex[:16]}'
@staticmethod
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
@staticmethod
def _query_session_id(query: pipeline_query.Query) -> str:
launcher_type = query.launcher_type.value if hasattr(query.launcher_type, 'value') else str(query.launcher_type)
return f'{launcher_type}_{query.launcher_id}'
enable_all_mcp_servers: bool
"""是否启用所有MCP服务器"""
@@ -234,44 +250,92 @@ class RuntimePipeline:
stage_container = self.stage_containers[i]
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
span_started_at = self._utc_now()
span_started = time.perf_counter()
span_status = 'success'
span_error = None
span_result_type = None
result = stage_container.inst.process(query, stage_container.inst_name)
try:
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
)
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')
async for sub_result in result:
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
span_result_type = str(result.result_type.value if hasattr(result.result_type, 'value') else result.result_type)
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
)
await self._check_output(query, sub_result)
await self._check_output(query, result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
span_result_type = 'generator'
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')
async for sub_result in result:
span_result_type = str(
sub_result.result_type.value
if hasattr(sub_result.result_type, 'value')
else sub_result.result_type
)
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
)
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
except Exception as e:
span_status = 'error'
span_error = str(e)
raise
finally:
trace_id = (query.variables or {}).get('_monitoring_trace_id')
root_span_id = (query.variables or {}).get('_monitoring_root_span_id')
if trace_id:
try:
await self.ap.monitoring_service.record_span(
trace_id=trace_id,
parent_span_id=root_span_id,
name=stage_container.inst_name,
kind='pipeline.stage',
status=span_status,
started_at=span_started_at,
duration=int((time.perf_counter() - span_started) * 1000),
message_id=(query.variables or {}).get('_monitoring_message_id'),
session_id=self._query_session_id(query),
bot_id=query.bot_uuid,
pipeline_id=self.pipeline_entity.uuid,
attributes={
'stage_class': stage_container.inst.__class__.__name__,
'result_type': span_result_type,
'query_id': query.query_id,
},
error_message=span_error,
)
except Exception as monitor_err:
self.ap.logger.error(f'Failed to record stage span: {monitor_err}')
i += 1
async def process_query(self, query: pipeline_query.Query):
"""处理请求"""
trace_started_at = self._utc_now()
trace_started = time.perf_counter()
root_span_id = self._new_span_id()
trace_id = None
trace_status = 'success'
# Get monitoring metadata
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
@@ -303,6 +367,28 @@ class RuntimePipeline:
except Exception as e:
self.ap.logger.error(f'Failed to record query start: {e}')
try:
trace_id = await self.ap.monitoring_service.start_trace(
name='LangBot query',
bot_id=query.bot_uuid or 'unknown',
bot_name=bot_name,
pipeline_id=self.pipeline_entity.uuid,
pipeline_name=pipeline_name,
session_id=self._query_session_id(query),
message_id=message_id or None,
query_id=query.query_id,
attributes={
'launcher_type': query.launcher_type.value
if hasattr(query.launcher_type, 'value')
else str(query.launcher_type),
'runner_name': runner_name,
},
)
query.variables['_monitoring_trace_id'] = trace_id
query.variables['_monitoring_root_span_id'] = root_span_id
except Exception as e:
self.ap.logger.error(f'Failed to start query trace: {e}')
try:
# Get bound plugins for this pipeline
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
@@ -361,6 +447,7 @@ class RuntimePipeline:
self.ap.logger.error(f'Failed to record query response: {e}')
except Exception as e:
trace_status = 'error'
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
@@ -383,6 +470,35 @@ class RuntimePipeline:
self.ap.logger.error(f'Failed to record query error: {me}')
finally:
if trace_id:
try:
duration_ms = int((time.perf_counter() - trace_started) * 1000)
await self.ap.monitoring_service.record_span(
trace_id=trace_id,
span_id=root_span_id,
name='LangBot query',
kind='pipeline.query',
status=trace_status,
started_at=trace_started_at,
duration=duration_ms,
message_id=message_id or None,
session_id=self._query_session_id(query),
bot_id=query.bot_uuid,
pipeline_id=self.pipeline_entity.uuid,
attributes={
'query_id': query.query_id,
'pipeline_name': pipeline_name,
'runner_name': runner_name,
},
)
await self.ap.monitoring_service.finish_trace(
trace_id=trace_id,
status=trace_status,
duration=duration_ms,
message_id=message_id or None,
)
except Exception as monitor_err:
self.ap.logger.error(f'Failed to finish query trace: {monitor_err}')
self.ap.logger.debug(f'Query {query.query_id} processed')
del self.ap.query_pool.cached_queries[query.query_id]

View File

@@ -711,8 +711,19 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
endpoint: str,
method: str,
body: Any = None,
caller: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
return await self.handler.handle_page_api(plugin_author, plugin_name, page_id, endpoint, method, body)
return await self.handler.handle_page_api(
plugin_author,
plugin_name,
page_id,
endpoint,
method,
body,
caller,
headers or {},
)
async def get_debug_info(self) -> dict[str, Any]:
"""Get debug information including debug key and WS URL"""

View File

@@ -755,6 +755,19 @@ class RuntimeConnectionHandler(handler.Handler):
'session_name': session_name,
'bot_uuid': query.bot_uuid or '',
'sender_id': str(query.sender_id),
'_trace_context': {
'trace_id': query.variables.get('_monitoring_trace_id') if query.variables else None,
'parent_span_id': query.variables.get('_monitoring_root_span_id') if query.variables else None,
'message_id': query.variables.get('_monitoring_message_id') if query.variables else None,
'query_id': query.query_id,
'session_id': session_name,
'bot_id': query.bot_uuid or '',
'pipeline_id': query.pipeline_uuid or '',
'knowledge_base_id': kb_id,
'attributes': {
'source': 'plugin-api',
},
},
},
)
results = [entry.model_dump(mode='json') for entry in entries]
@@ -1011,6 +1024,8 @@ class RuntimeConnectionHandler(handler.Handler):
endpoint: str,
method: str,
body: Any = None,
caller: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Forward a page API call to the plugin via runtime."""
result = await self.call_action(
@@ -1022,6 +1037,8 @@ class RuntimeConnectionHandler(handler.Handler):
'endpoint': endpoint,
'method': method,
'body': body,
'caller': caller,
'headers': headers or {},
},
timeout=30,
)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import abc
import typing
import time
import datetime
from ...core import app
from ...entity.persistence import model as persistence_model
@@ -16,6 +17,15 @@ LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
def _query_session_id(query: pipeline_query.Query) -> str:
launcher_type = query.launcher_type.value if hasattr(query.launcher_type, 'value') else str(query.launcher_type)
return f'{launcher_type}_{query.launcher_id}'
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
"""Store the latest provider usage on the query for upstream action handlers."""
if query is None or not usage_info:
@@ -59,6 +69,7 @@ class RuntimeProvider:
"""Bridge method for invoking LLM with monitoring"""
# Start timing for monitoring
start_time = time.time()
span_started_at = _utc_now()
input_tokens = 0
output_tokens = 0
status = 'success'
@@ -125,6 +136,30 @@ class RuntimeProvider:
error_message=error_message,
message_id=message_id,
)
trace_id = query.variables.get('_monitoring_trace_id') if query.variables else None
parent_span_id = query.variables.get('_monitoring_root_span_id') if query.variables else None
if trace_id:
await self.requester.ap.monitoring_service.record_span(
trace_id=trace_id,
parent_span_id=parent_span_id,
name=f'LLM {model.model_entity.name}',
kind='model.llm',
status=status,
started_at=span_started_at,
duration=duration_ms,
message_id=message_id,
session_id=_query_session_id(query),
bot_id=query.bot_uuid,
pipeline_id=query.pipeline_uuid,
attributes={
'model_name': model.model_entity.name,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': input_tokens + output_tokens,
'stream': False,
},
error_message=error_message,
)
except Exception as monitor_err:
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM call: {monitor_err}')
@@ -140,6 +175,7 @@ class RuntimeProvider:
"""Bridge method for invoking LLM stream with monitoring"""
# Start timing for monitoring
start_time = time.time()
span_started_at = _utc_now()
status = 'success'
error_message = None
input_tokens = 0
@@ -204,6 +240,30 @@ class RuntimeProvider:
error_message=error_message,
message_id=message_id,
)
trace_id = query.variables.get('_monitoring_trace_id') if query.variables else None
parent_span_id = query.variables.get('_monitoring_root_span_id') if query.variables else None
if trace_id:
await self.requester.ap.monitoring_service.record_span(
trace_id=trace_id,
parent_span_id=parent_span_id,
name=f'LLM stream {model.model_entity.name}',
kind='model.llm',
status=status,
started_at=span_started_at,
duration=duration_ms,
message_id=message_id,
session_id=_query_session_id(query),
bot_id=query.bot_uuid,
pipeline_id=query.pipeline_uuid,
attributes={
'model_name': model.model_entity.name,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': input_tokens + output_tokens,
'stream': True,
},
error_message=error_message,
)
except Exception as monitor_err:
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM stream call: {monitor_err}')

View File

@@ -268,6 +268,19 @@ class LocalAgentRunner(runner.RequestRunner):
'bot_uuid': query.bot_uuid or '',
'sender_id': str(query.sender_id),
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'_trace_context': {
'trace_id': query.variables.get('_monitoring_trace_id') if query.variables else None,
'parent_span_id': query.variables.get('_monitoring_root_span_id') if query.variables else None,
'message_id': query.variables.get('_monitoring_message_id') if query.variables else None,
'query_id': query.query_id,
'session_id': f'{query.launcher_type.value}_{query.launcher_id}',
'bot_id': query.bot_uuid or '',
'pipeline_id': query.pipeline_uuid or '',
'knowledge_base_id': kb_uuid,
'attributes': {
'source': 'local-agent',
},
},
},
)

View File

@@ -1,18 +0,0 @@
from __future__ import annotations
from typing import Any
async def is_box_backend_available(ap: Any) -> bool:
"""Return whether the configured Box backend is ready for tool execution."""
box_service = getattr(ap, 'box_service', None)
if box_service is None:
return False
if not getattr(box_service, 'available', False):
return False
try:
status = await box_service.get_status()
backend_info = status.get('backend', {})
return bool(backend_info.get('available', False))
except Exception:
return False

View File

@@ -5,8 +5,6 @@ import asyncio
import os
import shutil
import shlex
import threading
from contextlib import suppress
from typing import TYPE_CHECKING, Any
import pydantic
@@ -20,26 +18,12 @@ from ....box.workspace import (
rewrite_mounted_path,
rewrite_venv_command,
unwrap_venv_path,
wrap_python_command_with_env,
)
if TYPE_CHECKING:
from .mcp import RuntimeMCPSession
_WORKSPACE_COPY_LOCKS: dict[str, threading.Lock] = {}
_WORKSPACE_COPY_LOCKS_GUARD = threading.Lock()
def _workspace_copy_lock(path: str) -> threading.Lock:
with _WORKSPACE_COPY_LOCKS_GUARD:
lock = _WORKSPACE_COPY_LOCKS.get(path)
if lock is None:
lock = threading.Lock()
_WORKSPACE_COPY_LOCKS[path] = lock
return lock
class MCPSessionErrorPhase(enum.Enum):
"""Which phase of the MCP lifecycle failed."""
@@ -65,7 +49,7 @@ class MCPServerBoxConfig(pydantic.BaseModel):
host_path: str | None = None
host_path_mode: str = 'ro' # MCP servers default to read-write mount only when explicitly requested
env: dict[str, str] = pydantic.Field(default_factory=dict)
startup_timeout_sec: int = 300 # First Docker bootstrap may need to build a venv and install MCP deps.
startup_timeout_sec: int = 120 # Longer default to allow dependency bootstrap
cpus: float | None = None
memory_mb: int | None = None
pids_limit: int | None = None
@@ -144,7 +128,6 @@ class BoxStdioSessionRuntime:
workspace = self._build_workspace(host_path=None)
host_path = self.resolve_host_path()
process_cwd = '/workspace'
install_cmd: str | None = None
try:
await workspace.create_session()
@@ -185,8 +168,6 @@ class BoxStdioSessionRuntime:
env=self.server_config.get('env', {}),
cwd=process_cwd,
)
if install_cmd:
payload = self._wrap_process_payload_with_python_env(payload, process_cwd)
payload['process_id'] = self.process_id
await workspace.box_service.start_managed_process(workspace.session_id, payload)
except Exception:
@@ -272,42 +253,14 @@ class BoxStdioSessionRuntime:
@staticmethod
def _copy_workspace_tree(source_path: str, process_host_root: str, process_host_workspace: str) -> None:
# Docker-backed bootstrap writes root-owned runtime directories such as
# .venv/.tmp into the staged workspace. The host process may not be able
# to delete them, so refresh source files in place and preserve runtime
# directories instead of rmtree'ing the whole staging root.
with _workspace_copy_lock(process_host_root):
preserved_names = {'.venv', 'venv', 'env', '.cache', '.tmp', '.langbot'}
os.makedirs(process_host_workspace, exist_ok=True)
for name in os.listdir(process_host_workspace):
if name in preserved_names:
continue
path = os.path.join(process_host_workspace, name)
if os.path.isdir(path) and not os.path.islink(path):
shutil.rmtree(path, ignore_errors=True)
else:
# The entry may disappear between listdir and unlink if cleanup races us.
with suppress(FileNotFoundError):
os.unlink(path)
shutil.copytree(
source_path,
process_host_workspace,
symlinks=True,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns(
'.git',
'__pycache__',
'.pytest_cache',
'.mypy_cache',
'.ruff_cache',
'.venv',
'venv',
'env',
'.cache',
'.tmp',
'.langbot',
),
)
shutil.rmtree(process_host_root, ignore_errors=True)
os.makedirs(process_host_root, exist_ok=True)
shutil.copytree(
source_path,
process_host_workspace,
symlinks=True,
ignore=shutil.ignore_patterns('.git', '__pycache__', '.pytest_cache', '.mypy_cache', '.ruff_cache'),
)
async def _cleanup_staged_workspace(self) -> None:
if not self.resolve_host_path():
@@ -390,25 +343,23 @@ class BoxStdioSessionRuntime:
@staticmethod
def detect_install_command(host_path: str, workspace_path: str = '/workspace') -> str | None:
workspace_kind = classify_python_workspace(host_path)
if workspace_kind in {'package', 'requirements'}:
return wrap_python_command_with_env('python -c "pass"', mount_path=workspace_path).rstrip()
quoted_workspace_path = shlex.quote(workspace_path)
if workspace_kind == 'package':
return (
'mkdir -p /opt/_lb_src'
f' && tar -C {quoted_workspace_path}'
' --exclude=.venv --exclude=.git --exclude=__pycache__'
' --exclude=node_modules --exclude=.tox --exclude=.nox'
' --exclude="*.egg-info" --exclude=.uv-cache'
' -cf - .'
' | tar -C /opt/_lb_src -xf -'
' && pip install --no-cache-dir /opt/_lb_src'
' && rm -rf /opt/_lb_src'
)
if workspace_kind == 'requirements':
return f'pip install --no-cache-dir -r {quoted_workspace_path}/requirements.txt'
return None
@staticmethod
def _wrap_process_payload_with_python_env(payload: dict[str, Any], workspace_path: str) -> dict[str, Any]:
"""Start a prepared Python workspace without writing bootstrap output to MCP stdio."""
workspace_root = workspace_path.rstrip('/') or '/workspace'
venv_dir = f'{workspace_root}/.venv'
venv_bin = f'{venv_dir}/bin'
command = ' '.join([shlex.quote(payload['command']), *[shlex.quote(arg) for arg in payload.get('args', [])]])
wrapped = dict(payload)
wrapped['command'] = 'sh'
wrapped['args'] = [
'-lc',
(f'export VIRTUAL_ENV={shlex.quote(venv_dir)}; export PATH={shlex.quote(venv_bin)}:$PATH; exec {command}'),
]
return wrapped
def build_box_session_payload(self, session_id: str, host_path: str | None = None) -> dict[str, Any]:
workspace = self._build_workspace()
workspace.session_id = session_id

View File

@@ -8,7 +8,6 @@ from langbot_plugin.api.entities.events import pipeline_query
from .. import loader
from ..errors import ToolNotFoundError
from .availability import is_box_backend_available
from . import skill as skill_loader
EXEC_TOOL_NAME = 'exec'
@@ -23,15 +22,6 @@ _ALL_TOOL_NAMES = {EXEC_TOOL_NAME, READ_TOOL_NAME, WRITE_TOOL_NAME, EDIT_TOOL_NA
# Skip these dirs during grep walk to avoid noise
_SKIP_DIRS = {'.git', 'node_modules', '__pycache__', '.venv', 'venv', '.tox', 'dist', 'build'}
_DEFAULT_READ_MAX_LINES = 2000
_MAX_READ_MAX_LINES = 10000
_DEFAULT_TOOL_RESULT_MAX_BYTES = 50 * 1024
_BOX_FILE_SCRIPT_MAX_BYTES = 2048
_GLOB_MAX_MATCHES = 100
_GREP_MAX_MATCHES = 200
_GREP_MAX_FILES = 5000
_GREP_MAX_LINE_CHARS = 500
class NativeToolLoader(loader.ToolLoader):
def __init__(self, ap):
@@ -53,7 +43,18 @@ class NativeToolLoader(loader.ToolLoader):
async def _check_backend_available(self) -> bool:
"""Check if the box backend is truly available (not just the runtime)."""
return await is_box_backend_available(self.ap)
box_service = getattr(self.ap, 'box_service', None)
if box_service is None:
return False
if not getattr(box_service, 'available', False):
return False
# Check if backend is truly available via get_status
try:
status = await box_service.get_status()
backend_info = status.get('backend', {})
return backend_info.get('available', False)
except Exception:
return False
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
if not self._is_sandbox_available():
@@ -138,7 +139,6 @@ class NativeToolLoader(loader.ToolLoader):
# via execute_tool. Skills are mounted at /workspace/.skills/{name}/
# via extra_mounts built by BoxService.
result = await self.ap.box_service.execute_tool(parameters, query)
result = self._normalize_exec_result(result)
if selected_skill is not None:
self._refresh_skill_from_disk(selected_skill)
@@ -227,65 +227,19 @@ class NativeToolLoader(loader.ToolLoader):
except Exception:
return {'ok': False, 'error': stdout or 'Box file operation returned no result'}
async def _read_workspace_via_box(self, path: str, parameters: dict, query: pipeline_query.Query) -> dict:
offset = self._positive_int(parameters.get('offset'), default=1)
max_lines = self._positive_int(
parameters.get('limit'),
default=_DEFAULT_READ_MAX_LINES,
max_value=_MAX_READ_MAX_LINES,
)
# Box file fallback returns through exec stdout, which is already capped
# by BoxService. Keep this payload small enough to remain valid JSON.
max_bytes = min(
self._positive_int(parameters.get('max_bytes'), default=_DEFAULT_TOOL_RESULT_MAX_BYTES),
_BOX_FILE_SCRIPT_MAX_BYTES,
)
async def _read_workspace_via_box(self, path: str, query: pipeline_query.Query) -> dict:
script = f"""
import json, os
path = {json.dumps(path)}
offset = {offset}
max_lines = {max_lines}
max_bytes = {max_bytes}
if not path.startswith('/workspace'):
print(json.dumps({{'ok': False, 'error': 'Path must be under /workspace.'}}))
elif not os.path.exists(path):
print(json.dumps({{'ok': False, 'error': f'File not found: {{path}}'}}))
elif os.path.isdir(path):
entries = sorted(os.listdir(path))
content = '\\n'.join(entries)
print(json.dumps({{'ok': True, 'content': content, 'is_directory': True, 'total': len(entries), 'truncated': False}}))
print(json.dumps({{'ok': True, 'content': '\\n'.join(sorted(os.listdir(path))), 'is_directory': True}}))
else:
lines = []
output_bytes = 0
end_line = offset - 1
truncated = False
next_offset = None
with open(path, 'r', encoding='utf-8', errors='replace') as f:
for line_number, line in enumerate(f, 1):
if line_number < offset:
continue
if len(lines) >= max_lines:
truncated = True
next_offset = line_number
break
line_bytes = len(line.encode('utf-8'))
if output_bytes + line_bytes > max_bytes:
truncated = True
next_offset = line_number
break
lines.append(line.rstrip('\\n'))
output_bytes += line_bytes
end_line = line_number
print(json.dumps({{
'ok': True,
'content': '\\n'.join(lines),
'truncated': truncated,
'start_line': offset,
'end_line': end_line,
'next_offset': next_offset,
'max_lines': max_lines,
'max_bytes': max_bytes,
}}))
print(json.dumps({{'ok': True, 'content': f.read()}}))
""".strip()
return await self._run_workspace_file_script(script, query)
@@ -353,27 +307,12 @@ else:
if not any(part in skip_dirs for part in item.parts)
]
hits.sort(key=lambda item: item.stat().st_mtime if item.exists() else 0, reverse=True)
shown = hits[:{_GLOB_MAX_MATCHES}]
shown = hits[:100]
matches = []
output_bytes = 0
truncated_by_bytes = False
for item in shown:
rel = os.path.relpath(str(item), path)
sandbox_path = os.path.join(path, rel).replace(os.sep, '/')
entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if matches else 0)
if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}:
truncated_by_bytes = True
break
matches.append(sandbox_path)
output_bytes += entry_bytes
print(json.dumps({{
'ok': True,
'matches': matches,
'preview': '\\n'.join(matches),
'total': len(hits),
'truncated': len(hits) > len(matches) or truncated_by_bytes,
'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if len(hits) > len(matches) else None),
}}))
matches.append(os.path.join(path, rel).replace(os.sep, '/'))
print(json.dumps({{'ok': True, 'matches': matches, 'total': len(hits), 'truncated': len(hits) > 100}}))
""".strip()
return await self._run_workspace_file_script(script, query)
@@ -411,54 +350,29 @@ else:
continue
if item.is_file():
files.append(item)
if len(files) >= {_GREP_MAX_FILES}:
if len(files) >= 5000:
break
matches = []
output_bytes = 0
truncated_by = None
for fp in files:
try:
handle = fp.open('r', encoding='utf-8', errors='ignore')
text = fp.read_text(errors='ignore')
except OSError:
continue
with handle:
for lineno, line in enumerate(handle, 1):
if regex.search(line):
if base.is_file():
file_path = path
else:
rel = os.path.relpath(str(fp), path)
file_path = os.path.join(path, rel).replace(os.sep, '/')
content = line.rstrip()
line_truncated = False
if len(content) > {_GREP_MAX_LINE_CHARS}:
content = content[:{_GREP_MAX_LINE_CHARS}] + '... [truncated]'
line_truncated = True
entry = {{'file': file_path, 'line': lineno, 'content': content}}
entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1
if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}:
truncated_by = 'bytes'
break
if line_truncated and truncated_by is None:
truncated_by = 'line'
matches.append(entry)
output_bytes += entry_bytes
if len(matches) >= {_GREP_MAX_MATCHES}:
truncated_by = truncated_by or 'matches'
break
if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}:
break
if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}:
for lineno, line in enumerate(text.splitlines(), 1):
if regex.search(line):
if base.is_file():
file_path = path
else:
rel = os.path.relpath(str(fp), path)
file_path = os.path.join(path, rel).replace(os.sep, '/')
matches.append({{'file': file_path, 'line': lineno, 'content': line.rstrip()}})
if len(matches) >= 200:
break
if len(matches) >= 200:
break
print(json.dumps({{
'ok': True,
'matches': matches,
'total': len(matches),
'truncated': truncated_by is not None,
'truncated_by': truncated_by,
}}))
print(json.dumps({{'ok': True, 'matches': matches, 'total': len(matches), 'truncated': len(matches) >= 200}}))
""".strip()
return await self._run_workspace_file_script(script, query)
@@ -473,20 +387,14 @@ else:
)
if skill_request is not None and hasattr(self.ap.box_service, 'read_skill_file'):
selected_skill, relative = skill_request
host_path = self._resolve_skill_host_path(selected_skill, relative)
if host_path and os.path.exists(host_path):
if os.path.isdir(host_path):
return self._build_directory_result(os.listdir(host_path))
return self._read_text_file_preview(host_path, parameters)
try:
result = await self.ap.box_service.read_skill_file(selected_skill['name'], relative)
return self._build_read_result_from_text(str(result.get('content', '')), parameters)
return {'ok': True, 'content': result.get('content', '')}
except Exception:
try:
result = await self.ap.box_service.list_skill_files(selected_skill['name'], relative)
entries = [entry['name'] for entry in result.get('entries', [])]
return self._build_directory_result(entries)
return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True}
except Exception as exc:
return {'ok': False, 'error': str(exc)}
@@ -497,13 +405,15 @@ else:
include_activated=True,
)
if self._should_use_box_workspace_files(selected_skill):
return await self._read_workspace_via_box(path, parameters, query)
return await self._read_workspace_via_box(path, query)
if not os.path.exists(host_path):
return {'ok': False, 'error': f'File not found: {path}'}
if os.path.isdir(host_path):
entries = os.listdir(host_path)
return self._build_directory_result(entries)
return self._read_text_file_preview(host_path, parameters)
return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True}
with open(host_path, 'r', errors='replace') as f:
content = f.read()
return {'ok': True, 'content': content}
async def _invoke_write(self, parameters: dict, query: pipeline_query.Query) -> dict:
path = parameters['path']
@@ -674,28 +584,6 @@ else:
'type': 'string',
'description': 'Absolute path to the file (must be under /workspace).',
},
'offset': {
'type': 'integer',
'description': '1-indexed line number to start reading from. Defaults to 1.',
'default': 1,
'minimum': 1,
},
'limit': {
'type': 'integer',
'description': f'Maximum number of lines to return. Defaults to {_DEFAULT_READ_MAX_LINES}.',
'default': _DEFAULT_READ_MAX_LINES,
'minimum': 1,
'maximum': _MAX_READ_MAX_LINES,
},
'max_bytes': {
'type': 'integer',
'description': (
f'Maximum bytes of file content to return. Defaults to {_DEFAULT_TOOL_RESULT_MAX_BYTES}.'
),
'default': _DEFAULT_TOOL_RESULT_MAX_BYTES,
'minimum': 1,
'maximum': _DEFAULT_TOOL_RESULT_MAX_BYTES,
},
},
'required': ['path'],
'additionalProperties': False,
@@ -852,30 +740,22 @@ else:
hits.sort(key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True)
total = len(hits)
shown = hits[:_GLOB_MAX_MATCHES]
shown = hits[:100]
# Convert back to sandbox paths
sandbox_paths = []
output_bytes = 0
truncated_by_bytes = False
for h in shown:
rel = os.path.relpath(str(h), host_path)
sandbox_path = os.path.join(path, rel)
entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if sandbox_paths else 0)
if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES:
truncated_by_bytes = True
break
sandbox_paths.append(sandbox_path)
output_bytes += entry_bytes
return {
'ok': True,
'matches': sandbox_paths,
'preview': '\n'.join(sandbox_paths),
'total': total,
'truncated': total > len(sandbox_paths) or truncated_by_bytes,
'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if total > len(sandbox_paths) else None),
}
result_lines = sandbox_paths
result = '\n'.join(result_lines)
if total > 100:
result += f'\n... ({total} matches, showing first 100)'
return {'ok': True, 'matches': result_lines, 'total': total, 'truncated': total > 100}
async def _invoke_grep(self, parameters: dict, query: pipeline_query.Query) -> dict:
pattern = parameters['pattern']
@@ -911,46 +791,32 @@ else:
files = self._grep_walk(base, include)
matches = []
output_bytes = 0
truncated_by = None
for fp in files:
try:
handle = fp.open('r', encoding='utf-8', errors='ignore')
text = fp.read_text(errors='ignore')
except OSError:
continue
with handle:
for lineno, line in enumerate(handle, 1):
if regex.search(line):
rel = os.path.relpath(str(fp), host_path)
sandbox_path = os.path.join(path, rel)
content, line_truncated = self._truncate_grep_line(line.rstrip())
entry = {
for lineno, line in enumerate(text.splitlines(), 1):
if regex.search(line):
rel = os.path.relpath(str(fp), host_path)
sandbox_path = os.path.join(path, rel)
matches.append(
{
'file': sandbox_path,
'line': lineno,
'content': content,
'content': line.rstrip(),
}
entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1
if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES:
truncated_by = 'bytes'
break
if line_truncated and truncated_by is None:
truncated_by = 'line'
matches.append(entry)
output_bytes += entry_bytes
if len(matches) >= _GREP_MAX_MATCHES:
truncated_by = truncated_by or 'matches'
break
if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES:
break
if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES:
)
if len(matches) >= 200:
break
if len(matches) >= 200:
break
return {
'ok': True,
'matches': matches,
'total': len(matches),
'truncated': truncated_by is not None,
'truncated_by': truncated_by,
'truncated': len(matches) >= 200,
}
@staticmethod
@@ -962,207 +828,10 @@ else:
continue
if item.is_file():
results.append(item)
if len(results) >= _GREP_MAX_FILES:
if len(results) >= 5000:
break
return results
@staticmethod
def _resolve_skill_host_path(selected_skill: dict, relative: str) -> str | None:
package_root = str(selected_skill.get('package_root', '') or '').strip()
if not package_root:
return None
host_root = os.path.realpath(package_root)
host_path = os.path.realpath(os.path.join(host_root, relative))
if not (host_path == host_root or host_path.startswith(host_root + os.sep)):
raise ValueError('Path escapes the skill package boundary.')
return host_path
def _normalize_exec_result(self, result: dict) -> dict:
normalized = dict(result)
stdout = str(normalized.get('stdout') or '')
stderr = str(normalized.get('stderr') or '')
stdout, stdout_capped = self._truncate_text_to_bytes_with_flag(stdout, _DEFAULT_TOOL_RESULT_MAX_BYTES)
stderr, stderr_capped = self._truncate_text_to_bytes_with_flag(stderr, _DEFAULT_TOOL_RESULT_MAX_BYTES)
normalized['stdout'] = stdout
normalized['stderr'] = stderr
normalized['stdout_truncated'] = bool(normalized.get('stdout_truncated') or stdout_capped)
normalized['stderr_truncated'] = bool(normalized.get('stderr_truncated') or stderr_capped)
if stdout and stderr:
preview_raw = f'stdout:\n{stdout}\n\nstderr:\n{stderr}'
else:
preview_raw = stdout or stderr
preview, preview_capped = self._truncate_text_to_bytes_with_flag(preview_raw, _DEFAULT_TOOL_RESULT_MAX_BYTES)
normalized['preview'] = preview
normalized['truncated'] = bool(
normalized['stdout_truncated'] or normalized['stderr_truncated'] or preview_capped
)
if preview_capped and not normalized.get('truncated_by'):
normalized['truncated_by'] = 'bytes'
return normalized
def _build_directory_result(self, entries: list[str]) -> dict:
sorted_entries = sorted(str(entry) for entry in entries)
content = '\n'.join(sorted_entries)
preview = self._truncate_text_to_bytes(content, _DEFAULT_TOOL_RESULT_MAX_BYTES)
truncated = preview != content
return {
'ok': True,
'content': preview,
'is_directory': True,
'total': len(sorted_entries),
'truncated': truncated,
'truncated_by': 'bytes' if truncated else None,
}
def _read_text_file_preview(self, host_path: str, parameters: dict) -> dict:
offset = self._positive_int(parameters.get('offset'), default=1)
max_lines = self._positive_int(
parameters.get('limit'),
default=_DEFAULT_READ_MAX_LINES,
max_value=_MAX_READ_MAX_LINES,
)
max_bytes = self._positive_int(
parameters.get('max_bytes'),
default=_DEFAULT_TOOL_RESULT_MAX_BYTES,
max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES,
)
lines: list[str] = []
output_bytes = 0
end_line = offset - 1
truncated = False
truncated_by: str | None = None
next_offset: int | None = None
with open(host_path, 'r', encoding='utf-8', errors='replace') as f:
for line_number, line in enumerate(f, 1):
if line_number < offset:
continue
if len(lines) >= max_lines:
truncated = True
truncated_by = 'lines'
next_offset = line_number
break
line_bytes = len(line.encode('utf-8'))
if output_bytes + line_bytes > max_bytes:
truncated = True
truncated_by = 'bytes'
next_offset = line_number
break
lines.append(line.rstrip('\n'))
output_bytes += line_bytes
end_line = line_number
if not lines and truncated_by == 'bytes':
content = (
f'[Line {next_offset or offset} exceeds the {self._format_size(max_bytes)} read limit. '
'Use exec with a byte-range command for this line, or read a different offset.]'
)
else:
content = '\n'.join(lines)
return {
'ok': True,
'content': content,
'truncated': truncated,
'truncated_by': truncated_by,
'start_line': offset,
'end_line': end_line,
'next_offset': next_offset,
'max_lines': max_lines,
'max_bytes': max_bytes,
}
def _build_read_result_from_text(self, content: str, parameters: dict) -> dict:
offset = self._positive_int(parameters.get('offset'), default=1)
max_lines = self._positive_int(
parameters.get('limit'),
default=_DEFAULT_READ_MAX_LINES,
max_value=_MAX_READ_MAX_LINES,
)
max_bytes = self._positive_int(
parameters.get('max_bytes'),
default=_DEFAULT_TOOL_RESULT_MAX_BYTES,
max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES,
)
all_lines = content.splitlines()
start_index = offset - 1
if start_index >= len(all_lines) and all_lines:
return {'ok': False, 'error': f'Offset {offset} is beyond end of file ({len(all_lines)} lines total)'}
output_lines: list[str] = []
output_bytes = 0
truncated = False
truncated_by: str | None = None
next_offset: int | None = None
for index, line in enumerate(all_lines[start_index:], start_index + 1):
if len(output_lines) >= max_lines:
truncated = True
truncated_by = 'lines'
next_offset = index
break
line_bytes = len(line.encode('utf-8')) + (1 if output_lines else 0)
if output_bytes + line_bytes > max_bytes:
truncated = True
truncated_by = 'bytes'
next_offset = index
break
output_lines.append(line)
output_bytes += line_bytes
end_line = offset + len(output_lines) - 1
return {
'ok': True,
'content': '\n'.join(output_lines),
'truncated': truncated,
'truncated_by': truncated_by,
'start_line': offset,
'end_line': end_line,
'next_offset': next_offset,
'max_lines': max_lines,
'max_bytes': max_bytes,
}
@staticmethod
def _positive_int(value, *, default: int, max_value: int | None = None) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = default
if parsed <= 0:
parsed = default
if max_value is not None:
parsed = min(parsed, max_value)
return parsed
@staticmethod
def _truncate_grep_line(line: str) -> tuple[str, bool]:
if len(line) <= _GREP_MAX_LINE_CHARS:
return line, False
return f'{line[:_GREP_MAX_LINE_CHARS]}... [truncated]', True
@staticmethod
def _truncate_text_to_bytes(text: str, max_bytes: int) -> str:
return NativeToolLoader._truncate_text_to_bytes_with_flag(text, max_bytes)[0]
@staticmethod
def _truncate_text_to_bytes_with_flag(text: str, max_bytes: int) -> tuple[str, bool]:
data = text.encode('utf-8')
if len(data) <= max_bytes:
return text, False
truncated = data[:max_bytes]
while truncated and (truncated[-1] & 0xC0) == 0x80:
truncated = truncated[:-1]
return truncated.decode('utf-8', errors='ignore'), True
@staticmethod
def _format_size(bytes_count: int) -> str:
if bytes_count < 1024:
return f'{bytes_count}B'
return f'{bytes_count / 1024:.1f}KB'
def _summarize_parameters(self, parameters: dict) -> dict:
summary = dict(parameters)
cmd = str(summary.get('command', '')).strip()

View File

@@ -72,45 +72,6 @@ def register_activated_skill(query: pipeline_query.Query, skill_data: dict) -> N
activated[skill_name] = skill_data
def normalize_skill_names(value: typing.Any) -> list[str]:
"""Return a de-duplicated list of non-empty skill names."""
if not isinstance(value, list):
return []
names: list[str] = []
for item in value:
skill_name = str(item or '').strip()
if skill_name and skill_name not in names:
names.append(skill_name)
return names
def get_activated_skill_names(query: pipeline_query.Query) -> list[str]:
"""Return activated skill names for callers that own persistence policy."""
return normalize_skill_names(list(get_activated_skills(query).keys()))
def restore_activated_skills(
ap: app.Application,
query: pipeline_query.Query,
skill_names: typing.Any,
) -> list[str]:
"""Restore caller-provided activated skill names into Query variables.
Persistence and state scope ownership belong to higher-level flows. This
helper only rebuilds current Query state from pipeline-visible skills, so
removed or unbound skills stay unavailable to native exec/write/edit.
"""
restored: list[str] = []
for skill_name in normalize_skill_names(skill_names):
skill_data = get_visible_skill(ap, query, skill_name)
if skill_data is None:
continue
register_activated_skill(query, skill_data)
restored.append(skill_name)
return restored
def parse_skill_mount_path(sandbox_path: str) -> tuple[str | None, str]:
normalized_path = str(sandbox_path or '/workspace').strip() or '/workspace'
if normalized_path == SKILL_MOUNT_PREFIX:

View File

@@ -6,7 +6,6 @@ import typing
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from .. import loader
from .availability import is_box_backend_available
# Align with Claude Code's Skill tool design:
# - activate: Activate a skill via Tool Call, returns SKILL.md content
@@ -46,7 +45,18 @@ class SkillToolLoader(loader.ToolLoader):
async def _check_sandbox_available(self) -> bool:
"""Check if the box backend is truly available (not just the runtime)."""
return await is_box_backend_available(self.ap)
box_service = getattr(self.ap, 'box_service', None)
if box_service is None:
return False
if not getattr(box_service, 'available', False):
return False
# Check if backend is truly available via get_status
try:
status = await box_service.get_status()
backend_info = status.get('backend', {})
return backend_info.get('available', False)
except Exception:
return False
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
if not self._is_available():
@@ -82,15 +92,16 @@ class SkillToolLoader(loader.ToolLoader):
if not skill_name:
raise ValueError('skill_name is required')
from . import skill as skill_loader
skill_data = skill_loader.get_visible_skill(self.ap, query, skill_name)
skill_mgr = self.ap.skill_mgr
skill_data = skill_mgr.get_skill_by_name(skill_name)
if skill_data is None:
visible_skills = skill_loader.get_visible_skills(self.ap, query)
visible_skills = getattr(skill_mgr, 'skills', {})
available_names = ', '.join(sorted(visible_skills.keys())) or 'none'
raise ValueError(f'Skill "{skill_name}" not found. Available skills: {available_names}')
# Register activated skill for sandbox mount path resolution
from . import skill as skill_loader
skill_loader.register_activated_skill(query, skill_data)
# Return SKILL.md content as Tool Result (injects into context)
@@ -116,7 +127,6 @@ class SkillToolLoader(loader.ToolLoader):
'activated': True,
'skill_name': skill_name,
'mount_path': mount_path,
'activated_skill_names': skill_loader.get_activated_skill_names(query),
'content': result_content,
}
@@ -191,13 +201,13 @@ class SkillToolLoader(loader.ToolLoader):
return resource_tool.LLMTool(
name=ACTIVATE_SKILL_TOOL_NAME,
human_desc='Activate a skill',
description='Activate a pipeline-visible skill by name and return its instructions as a tool result.',
description=self._build_activate_tool_description(),
parameters={
'type': 'object',
'properties': {
'skill_name': {
'type': 'string',
'description': 'The skill name to activate.',
'description': 'The skill name to activate (no arguments). E.g., "pdf" or "data-analysis"',
},
},
'required': ['skill_name'],
@@ -245,3 +255,50 @@ class SkillToolLoader(loader.ToolLoader):
},
func=lambda parameters: parameters,
)
def _build_activate_tool_description(self) -> str:
"""Build tool description with embedded available_skills list."""
skill_mgr = getattr(self.ap, 'skill_mgr', None)
if skill_mgr is None:
return 'Activate a skill. No skills are currently available.'
skills = getattr(skill_mgr, 'skills', {})
if not skills:
return 'Activate a skill. No skills are currently available.'
# Build <available_skills> section
available_skills_lines = ['<available_skills>']
for skill_name, skill_data in sorted(skills.items()):
description = skill_data.get('description', '')
available_skills_lines.append('<skill>')
available_skills_lines.append(f'<name>{skill_name}</name>')
available_skills_lines.append(f'<description>{description}</description>')
available_skills_lines.append('</skill>')
available_skills_lines.append('</available_skills>')
available_skills_block = '\n'.join(available_skills_lines)
return f"""Activate a skill within the main conversation.
<skills_instructions>
When users ask you to perform tasks, check if any of the available skills
below can help complete the task more effectively. Skills provide specialized
capabilities and domain knowledge.
How to use skills:
- Invoke skills using this tool with the skill name only (no arguments)
- When you invoke a skill, you will see <command-message>
The skill is activated
</command-message>
- The skill's instructions will be provided in the tool result
- Examples:
- skill_name: "pdf" - invoke the pdf skill
- skill_name: "data-analysis" - invoke the data-analysis skill
Important:
- Only use skills listed in <available_skills> below
- Do not invoke a skill that is already running
- To create a new skill: prepare it in /workspace, then use register_skill tool
</skills_instructions>
{available_skills_block}"""

View File

@@ -5,6 +5,7 @@ import traceback
import uuid
import zipfile
import io
import datetime
from typing import Any
from langbot.pkg.core import app
import sqlalchemy
@@ -25,6 +26,10 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
super().__init__(ap)
self.knowledge_base_entity = knowledge_base_entity
@staticmethod
def _utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
async def initialize(self):
pass
@@ -334,6 +339,24 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
# are passed directly to vector_search by some plugins (e.g. LangRAG)
# and would cause empty results when the metadata field doesn't exist.
filters = settings.pop('filters', {})
trace_context = settings.pop('_trace_context', None)
host_span_started_at = self._utc_now()
host_span_id = None
if trace_context and trace_context.get('trace_id'):
host_parent_span_id = trace_context.get('parent_span_id')
host_span_id = trace_context.get('rag_span_id') or f'span-{uuid.uuid4().hex[:16]}'
trace_context = {
'trace_id': trace_context.get('trace_id'),
'parent_span_id': host_span_id,
'host_parent_span_id': host_parent_span_id,
'message_id': trace_context.get('message_id'),
'query_id': trace_context.get('query_id'),
'session_id': trace_context.get('session_id'),
'bot_id': trace_context.get('bot_id'),
'pipeline_id': trace_context.get('pipeline_id'),
'knowledge_base_id': kb.uuid,
'attributes': trace_context.get('attributes') or {},
}
retrieval_context = {
'query': query,
@@ -343,13 +366,104 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
'creation_settings': kb.creation_settings or {},
'filters': filters,
}
if trace_context:
retrieval_context['trace_context'] = trace_context
result = await self.ap.plugin_connector.call_rag_retrieve(
plugin_id,
retrieval_context,
)
try:
result = await self.ap.plugin_connector.call_rag_retrieve(
plugin_id,
retrieval_context,
)
except Exception as e:
if trace_context:
await self._record_rag_trace_result(
trace_context=trace_context,
host_span_id=host_span_id,
started_at=host_span_started_at,
plugin_id=plugin_id,
result={
'results': [],
'metadata': {
'status': 'error',
'error_message': str(e),
},
},
)
raise
if trace_context:
await self._record_rag_trace_result(
trace_context=trace_context,
host_span_id=host_span_id,
started_at=host_span_started_at,
plugin_id=plugin_id,
result=result,
)
return result
async def _record_rag_trace_result(
self,
trace_context: dict[str, Any],
host_span_id: str | None,
started_at: datetime.datetime,
plugin_id: str,
result: dict[str, Any],
) -> None:
"""Persist host RAG span and plugin-provided child spans."""
trace_id = trace_context.get('trace_id')
if not trace_id:
return
metadata = result.get('metadata') if isinstance(result, dict) else {}
metadata = metadata if isinstance(metadata, dict) else {}
plugin_spans = metadata.get('trace_spans') if isinstance(metadata.get('trace_spans'), list) else []
parent_span_id = trace_context.get('parent_span_id')
host_parent_span_id = trace_context.get('host_parent_span_id')
try:
await self.ap.monitoring_service.record_span(
trace_id=trace_id,
span_id=host_span_id,
parent_span_id=host_parent_span_id,
name=f'Knowledge retrieval {self.knowledge_base_entity.name}',
kind='rag.retrieval',
status=metadata.get('status', 'success'),
started_at=started_at,
duration=metadata.get('duration_ms'),
message_id=trace_context.get('message_id'),
session_id=trace_context.get('session_id'),
bot_id=trace_context.get('bot_id'),
pipeline_id=trace_context.get('pipeline_id'),
attributes={
'knowledge_base_id': self.knowledge_base_entity.uuid,
'knowledge_base_name': self.knowledge_base_entity.name,
'plugin_id': plugin_id,
'returned_count': len(result.get('results', []) if isinstance(result, dict) else []),
'total_found': result.get('total_found') if isinstance(result, dict) else None,
},
error_message=metadata.get('error_message'),
)
for span in plugin_spans:
if not isinstance(span, dict):
continue
await self.ap.monitoring_service.record_span(
trace_id=trace_id,
span_id=span.get('span_id'),
parent_span_id=span.get('parent_span_id') or host_span_id or parent_span_id,
name=span.get('name') or 'RAG plugin stage',
kind=span.get('kind') or 'rag.stage',
status=span.get('status') or 'success',
started_at=started_at,
duration=span.get('duration_ms'),
message_id=trace_context.get('message_id'),
session_id=trace_context.get('session_id'),
bot_id=trace_context.get('bot_id'),
pipeline_id=trace_context.get('pipeline_id'),
attributes=span.get('attributes') if isinstance(span.get('attributes'), dict) else {},
error_message=span.get('error_message'),
)
except Exception as e:
self.ap.logger.error(f'Failed to record RAG trace spans: {e}')
async def _delete_document(self, document_id: str) -> bool:
"""Call plugin to delete document."""
kb = self.knowledge_base_entity

View File

@@ -1,6 +1,7 @@
# LangBot Test Suite
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
This directory contains the LangBot backend test suite, including unit tests,
integration tests, startup E2E tests, and container-backed Box runtime tests.
## Quality Gate Layers
@@ -10,10 +11,15 @@ LangBot uses a layered quality gate system for developers and CI:
|-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
**Note**: PostgreSQL migration tests and slow tests are NOT in local default
gates. They run in separate CI workflows. Frontend Playwright tests live under
`web/tests/e2e` and are documented in `web/README.md`.
### Developer Workflow
@@ -28,6 +34,9 @@ make test-all-local
bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min
uv run --python 3.12 pytest tests/e2e -q --tb=short
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
cd web && pnpm test:e2e
```
### Coverage Baseline
@@ -70,6 +79,12 @@ tests/
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── e2e/ # Real LangBot startup E2E tests
│ ├── conftest.py
│ ├── test_startup.py
│ └── utils/
├── integration_tests/ # Container-backed integration tests
│ └── box/ # Box runtime and MCP process tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
@@ -303,6 +318,44 @@ These tests:
- Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys
### Running backend E2E startup tests
Backend E2E tests start a real LangBot process with a generated minimal
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
They do not require provider keys or external services.
```bash
uv run --python 3.12 pytest tests/e2e -q --tb=short
```
These tests verify startup orchestration, migrations, API route registration,
and the minimal no-LLM startup path. The E2E process manager disables ambient
proxy variables for subprocess startup and uses direct localhost HTTP clients,
so local proxy settings should not affect the health checks.
### Running Box integration tests
Box integration tests exercise the real sandbox runtime path, including command
execution, session persistence, managed process WebSocket attachment, and
cleanup behavior.
```bash
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
```
These tests require a working Docker or Podman runtime. In CI, the dedicated
Box integration job checks Docker availability before running the tests.
### Running frontend E2E tests
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
and mock the LangBot backend and Space APIs, so no backend process is required.
```bash
cd web
pnpm test:e2e
```
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
@@ -320,6 +373,9 @@ Tests are automatically run on:
- Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
they exercise process/container behavior instead of pure Python compatibility.
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
## Adding New Tests
@@ -406,4 +462,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
- [ ] Add E2E tests
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis
- [ ] Add property-based testing with Hypothesis

View File

@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0) as client:
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client:
yield client
@pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db'
return e2e_tmpdir / 'data' / 'langbot.db'

View File

@@ -38,12 +38,13 @@ class TestStartupFlow:
# System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, e2e_db_path):
def test_database_initialized(self, langbot_process, e2e_db_path):
"""Verify SQLite database was created and initialized."""
assert e2e_db_path.exists()
# Database should have some tables after migration
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
@@ -74,10 +75,13 @@ class TestStartupFlow:
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint."""
# First startup may allow initial setup
response = e2e_client.post('/api/v1/user/auth', json={
'username': 'admin',
'password': 'admin',
})
response = e2e_client.post(
'/api/v1/user/auth',
json={
'user': 'admin',
'password': 'admin',
},
)
# Response could be:
# - 200 if auth succeeds
@@ -94,9 +98,10 @@ class TestStartupStages:
# If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, e2e_db_path):
def test_migrations_applied(self, langbot_process, e2e_db_path):
"""Verify database migrations were applied."""
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()

View File

@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
for path in directories.values():
path.mkdir(parents=True, exist_ok=True)
return directories
return directories

View File

@@ -44,6 +44,17 @@ class LangBotProcess:
# Prepare environment
env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src')
for proxy_key in (
'HTTP_PROXY',
'HTTPS_PROXY',
'ALL_PROXY',
'http_proxy',
'https_proxy',
'all_proxy',
):
env.pop(proxy_key, None)
env['NO_PROXY'] = '127.0.0.1,localhost'
env['no_proxy'] = '127.0.0.1,localhost'
# Set API port via environment variable
env['API__PORT'] = str(self.port)
@@ -79,9 +90,11 @@ precision = 2
f.write(coveragerc_content)
cmd = [
'coverage', 'run',
'coverage',
'run',
'--rcfile=' + str(coveragerc_path),
'-m', 'langbot',
'-m',
'langbot',
]
else:
cmd = ['uv', 'run', 'python', '-m', 'langbot']
@@ -113,6 +126,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0,
follow_redirects=False,
trust_env=False,
)
if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}')
@@ -185,6 +200,8 @@ precision = 2
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0,
follow_redirects=False,
trust_env=False,
)
return r.status_code == 200
except Exception:
@@ -201,4 +218,4 @@ def find_project_root() -> Path:
return parent
# Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build')
return Path('/home/glwuy/langbot-app/LangBot-test-build')

View File

@@ -58,45 +58,45 @@ from tests.factories.platform import (
__all__ = [
# App
"FakeApp",
"fake_app",
'FakeApp',
'fake_app',
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
'text_chain',
'group_text_chain',
'mention_chain',
'image_chain',
# Message events
"friend_message_event",
"group_message_event",
'friend_message_event',
'group_message_event',
# Mock adapters
"mock_adapter",
'mock_adapter',
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
"image_query",
"file_query",
"unsupported_query",
"voice_query",
"at_all_query",
"query_with_session",
"query_with_config",
'text_query',
'group_text_query',
'private_text_query',
'command_query',
'mention_query',
'empty_query',
'image_query',
'file_query',
'unsupported_query',
'voice_query',
'at_all_query',
'query_with_session',
'query_with_config',
# Provider
"FakeProvider",
"fake_provider",
"fake_provider_pong",
"fake_provider_timeout",
"fake_provider_auth_error",
"fake_provider_rate_limit",
"fake_provider_malformed",
"fake_model",
'FakeProvider',
'fake_provider',
'fake_provider_pong',
'fake_provider_timeout',
'fake_provider_auth_error',
'fake_provider_rate_limit',
'fake_provider_malformed',
'fake_model',
# Platform
"FakePlatform",
"fake_platform",
"fake_platform_with_streaming",
"fake_platform_with_failure",
"mock_platform_adapter",
]
'FakePlatform',
'fake_platform',
'fake_platform_with_streaming',
'fake_platform_with_failure',
'mock_platform_adapter',
]

View File

@@ -30,32 +30,36 @@ def _next_query_id() -> int:
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
def text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
return platform_message.MessageChain(
[
platform_message.At(target=target),
platform_message.Plain(text=f' {text}'),
]
)
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
@@ -66,13 +70,15 @@ def image_chain(
def command_chain(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
return platform_message.MessageChain(
[
platform_message.Plain(text=f'{prefix}{command}'),
]
)
# ============== Message Event Factories ==============
@@ -81,7 +87,7 @@ def command_chain(
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
@@ -90,7 +96,7 @@ def friend_message_event(
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -100,9 +106,9 @@ def friend_message_event(
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
@@ -117,7 +123,7 @@ def group_message_event(
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=message_chain,
time=1609459200,
@@ -152,36 +158,36 @@ def _base_query(
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'query_id': query_id,
'launcher_type': launcher_type,
'launcher_id': launcher_id,
'sender_id': sender_id,
'message_chain': message_chain,
'message_event': message_event,
'adapter': adapter,
'pipeline_uuid': 'test-pipeline-uuid',
'bot_uuid': 'test-bot-uuid',
'pipeline_config': {
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
'session': None,
'prompt': None,
'messages': [],
'user_message': None,
'use_funcs': [],
'use_llm_model_uuid': None,
'variables': {},
'resp_messages': [],
'resp_message_chain': None,
'current_stage_name': None,
}
# Apply overrides
@@ -192,7 +198,7 @@ def _base_query(
def text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -212,7 +218,7 @@ def text_query(
def private_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -221,7 +227,7 @@ def private_text_query(
def group_text_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -242,8 +248,8 @@ def group_text_query(
def command_query(
command: str = "help",
prefix: str = "/",
command: str = 'help',
prefix: str = '/',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -263,7 +269,7 @@ def command_query(
def mention_query(
text: str = "hello",
text: str = 'hello',
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
@@ -301,8 +307,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
text: str = '',
url: str = 'https://example.com/image.png',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -322,9 +328,9 @@ def image_query(
def file_query(
url: str = "https://example.com/document.pdf",
name: str = "document.pdf",
text: str = "",
url: str = 'https://example.com/document.pdf',
name: str = 'document.pdf',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -348,8 +354,8 @@ def file_query(
def unsupported_query(
unsupported_type: str = "CustomComponent",
text: str = "",
unsupported_type: str = 'CustomComponent',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -358,7 +364,7 @@ def unsupported_query(
if text:
components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}'))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
@@ -374,7 +380,7 @@ def unsupported_query(
def query_with_session(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None,
**overrides,
@@ -389,7 +395,7 @@ def query_with_session(
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -398,7 +404,7 @@ def query_with_session(
def query_with_config(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None,
**overrides,
@@ -410,22 +416,22 @@ def query_with_config(
"""
if pipeline_config is None:
pipeline_config = {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'test-prompt',
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
'trigger': {'misc': {'combine-quote-message': False}},
}
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query(
url: str = "https://example.com/audio.mp3",
url: str = 'https://example.com/audio.mp3',
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
@@ -448,7 +454,7 @@ def voice_query(
def at_all_query(
text: str = "hello",
text: str = 'hello',
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
@@ -456,7 +462,7 @@ def at_all_query(
"""Create a group query with @All mention."""
components = [
platform_message.AtAll(),
platform_message.Plain(text=f" {text}"),
platform_message.Plain(text=f' {text}'),
]
chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id)
@@ -469,4 +475,4 @@ def at_all_query(
sender_id=sender_id,
adapter=adapter,
**overrides,
)
)

View File

@@ -33,7 +33,7 @@ class FakePlatform:
def __init__(
self,
*,
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
raise_error: Exception = None,
):
@@ -48,16 +48,16 @@ class FakePlatform:
# Registered listeners
self._listeners: dict = {}
def raises(self, error: Exception) -> "FakePlatform":
def raises(self, error: Exception) -> 'FakePlatform':
"""Configure platform to raise an error on send."""
self._raise_error = error
return self
def send_failure(self) -> "FakePlatform":
def send_failure(self) -> 'FakePlatform':
"""Configure platform to simulate send failure."""
return self.raises(Exception("Platform send failure"))
return self.raises(Exception('Platform send failure'))
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
def supports_streaming(self, supported: bool = True) -> 'FakePlatform':
"""Configure whether streaming output is supported."""
self._stream_output_supported = supported
return self
@@ -89,7 +89,7 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
nickname: str = 'TestUser',
) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event."""
sender = platform_entities.Friend(
@@ -97,11 +97,13 @@ class FakePlatform:
nickname=nickname,
remark=None,
)
chain = platform_message.MessageChain([
platform_message.Plain(text=text),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text=text),
]
)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -111,9 +113,9 @@ class FakePlatform:
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
sender_name: str = 'TestUser',
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
group_name: str = 'TestGroup',
mention_bot: bool = False,
) -> platform_events.GroupMessage:
"""Create an inbound group message event.
@@ -142,12 +144,12 @@ class FakePlatform:
components = []
if mention_bot:
components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=' '))
components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components)
return platform_events.GroupMessage(
type="GroupMessage",
type='GroupMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -155,8 +157,8 @@ class FakePlatform:
def create_image_message(
self,
url: str = "https://example.com/image.png",
text: str = "",
url: str = 'https://example.com/image.png',
text: str = '',
sender_id: typing.Union[int, str] = 12345,
is_group: bool = False,
group_id: typing.Union[int, str] = 99999,
@@ -169,12 +171,12 @@ class FakePlatform:
chain = platform_message.MessageChain(components)
if is_group:
return self.create_group_message("", sender_id, group_id=group_id)
return self.create_group_message('', sender_id, group_id=group_id)
# Replace chain
else:
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None)
return platform_events.FriendMessage(
type="FriendMessage",
type='FriendMessage',
sender=sender,
message_chain=chain,
time=1609459200,
@@ -192,12 +194,14 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "send",
"target_type": target_type,
"target_id": target_id,
"message": message,
})
self._outbound_messages.append(
{
'type': 'send',
'target_type': target_type,
'target_id': target_id,
'message': message,
}
)
async def reply_message(
self,
@@ -209,13 +213,15 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "reply",
"source_type": message_source.type,
"source": message_source,
"message": message,
"quote_origin": quote_origin,
})
self._outbound_messages.append(
{
'type': 'reply',
'source_type': message_source.type,
'source': message_source,
'message': message,
'quote_origin': quote_origin,
}
)
async def reply_message_chunk(
self,
@@ -229,15 +235,17 @@ class FakePlatform:
if self._raise_error:
raise self._raise_error
self._outbound_chunks.append({
"type": "reply_chunk",
"source_type": message_source.type,
"source": message_source,
"bot_message": bot_message,
"message": message,
"quote_origin": quote_origin,
"is_final": is_final,
})
self._outbound_chunks.append(
{
'type': 'reply_chunk',
'source_type': message_source.type,
'source': message_source,
'bot_message': bot_message,
'message': message,
'quote_origin': quote_origin,
'is_final': is_final,
}
)
async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported."""
@@ -295,7 +303,7 @@ class FakePlatform:
def fake_platform(
bot_account_id: str = "test-bot",
bot_account_id: str = 'test-bot',
stream_output_supported: bool = False,
) -> FakePlatform:
"""Create a FakePlatform instance."""
@@ -328,9 +336,7 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported)
adapter._fake_platform = platform # Store for assertions
return adapter
return adapter

View File

@@ -27,51 +27,51 @@ class FakeProvider:
Does not require API keys.
"""
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
PONG_RESPONSE = 'LANGBOT_FAKE_PONG'
def __init__(
self,
*,
default_response: str = "fake response",
default_response: str = 'fake response',
streaming_chunks: list[str] = None,
raise_error: Exception = None,
captured_requests: list = None,
):
self._default_response = default_response
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._streaming_chunks = streaming_chunks or ['fake ', 'response']
self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> "FakeProvider":
def returns(self, text: str) -> 'FakeProvider':
"""Configure provider to return a specific text response."""
self._default_response = text
self._streaming_chunks = [text]
return self
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
"""Configure provider to return streaming chunks."""
self._streaming_chunks = chunks
self._default_response = "".join(chunks)
self._default_response = ''.join(chunks)
return self
def raises(self, error: Exception) -> "FakeProvider":
def raises(self, error: Exception) -> 'FakeProvider':
"""Configure provider to raise an error."""
self._raise_error = error
return self
def timeout(self) -> "FakeProvider":
def timeout(self) -> 'FakeProvider':
"""Configure provider to simulate timeout."""
return self.raises(TimeoutError("Provider timeout"))
return self.raises(TimeoutError('Provider timeout'))
def auth_error(self) -> "FakeProvider":
def auth_error(self) -> 'FakeProvider':
"""Configure provider to simulate auth error."""
return self.raises(Exception("Invalid API key"))
return self.raises(Exception('Invalid API key'))
def rate_limit(self) -> "FakeProvider":
def rate_limit(self) -> 'FakeProvider':
"""Configure provider to simulate rate limit."""
return self.raises(Exception("Rate limit exceeded"))
return self.raises(Exception('Rate limit exceeded'))
def malformed(self) -> "FakeProvider":
def malformed(self) -> 'FakeProvider':
"""Configure provider to simulate malformed response."""
self._default_response = None
return self
@@ -87,7 +87,7 @@ class FakeProvider:
def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content."""
return provider_message.Message(
role="assistant",
role='assistant',
content=content,
)
@@ -99,7 +99,7 @@ class FakeProvider:
) -> provider_message.MessageChunk:
"""Create a provider message chunk."""
return provider_message.MessageChunk(
role="assistant",
role='assistant',
content=content,
is_final=is_final,
msg_sequence=msg_sequence,
@@ -116,13 +116,15 @@ class FakeProvider:
) -> provider_message.Message:
"""Simulate non-streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
}
)
# Simulate error if configured
if self._raise_error:
@@ -131,7 +133,7 @@ class FakeProvider:
# Return response
if self._default_response is None:
# Malformed response
return provider_message.Message(role="assistant", content=None)
return provider_message.Message(role='assistant', content=None)
return self._create_message(self._default_response)
@@ -146,14 +148,16 @@ class FakeProvider:
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
"streaming": True,
})
self._captured_requests.append(
{
'query_id': query.query_id if query else None,
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
'messages': messages,
'funcs': funcs,
'extra_args': extra_args,
'streaming': True,
}
)
# Simulate error if configured
if self._raise_error:
@@ -161,12 +165,12 @@ class FakeProvider:
# Yield chunks
for i, chunk in enumerate(self._streaming_chunks):
is_final = (i == len(self._streaming_chunks) - 1)
is_final = i == len(self._streaming_chunks) - 1
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider(
default_response: str = "fake response",
default_response: str = 'fake response',
) -> FakeProvider:
"""Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response)
@@ -202,8 +206,8 @@ def fake_provider_malformed() -> FakeProvider:
def fake_model(
*,
uuid: str = "test-model-uuid",
name: str = "test-model",
uuid: str = 'test-model-uuid',
name: str = 'test-model',
abilities: list[str] = None,
provider: FakeProvider = None,
) -> Mock:
@@ -212,7 +216,7 @@ def fake_model(
model.model_entity = Mock()
model.model_entity.uuid = uuid
model.model_entity.name = name
model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.abilities = abilities or ['func_call', 'vision']
model.model_entity.extra_args = {}
# Attach fake provider
@@ -221,4 +225,4 @@ def fake_model(
model.provider = provider
return model
return model

View File

@@ -3,4 +3,4 @@ Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""
"""

View File

@@ -2,4 +2,4 @@
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""
"""

View File

@@ -48,6 +48,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield
@@ -56,10 +57,12 @@ def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -71,28 +74,29 @@ def fake_bot_app():
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[
{
app.bot_service.get_bots = AsyncMock(
return_value=[
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
]
)
app.bot_service.get_runtime_bot_info = AsyncMock(
return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
}
])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
)
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1))
app.bot_service.send_message = AsyncMock()
# Platform manager
@@ -118,10 +122,7 @@ class TestBotEndpoints:
@pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'})
assert response.status_code == 200
data = await response.get_json()
@@ -135,7 +136,7 @@ class TestBotEndpoints:
response = await quart_test_client.post(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'},
)
assert response.status_code == 200
@@ -147,8 +148,7 @@ class TestBotEndpoints:
async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -162,7 +162,7 @@ class TestBotEndpoints:
response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}
json={'name': 'Updated Bot'},
)
assert response.status_code == 200
@@ -173,8 +173,7 @@ class TestBotEndpoints:
async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -190,7 +189,7 @@ class TestBotLogsEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}
json={'from_index': -1, 'max_count': 10},
)
assert response.status_code == 200
@@ -213,8 +212,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'person',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 200
@@ -228,7 +227,7 @@ class TestBotSendMessageEndpoint:
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]},
)
assert response.status_code == 400
@@ -244,8 +243,8 @@ class TestBotSendMessageEndpoint:
json={
'target_type': 'invalid',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
'message_chain': [{'type': 'text', 'text': 'Hello'}],
},
)
assert response.status_code == 400

View File

@@ -47,6 +47,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield
@@ -55,10 +56,12 @@ def fake_embed_app():
"""Create FakeApp with embed widget services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock()
@@ -83,9 +86,7 @@ def fake_embed_app():
# WebSocket proxy bot with adapter
mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}])
mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock()
@@ -117,9 +118,7 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js')
assert response.status_code == 200
assert 'javascript' in response.content_type
@@ -127,18 +126,14 @@ class TestEmbedWidgetEndpoint:
@pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js')
assert response.status_code == 400
@pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js')
assert response.status_code == 404
@@ -164,8 +159,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 200
@@ -177,8 +171,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'}
)
assert response.status_code == 400
@@ -187,8 +180,7 @@ class TestEmbedTurnstileVerifyEndpoint:
async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={}
)
assert response.status_code == 400
@@ -203,7 +195,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/person returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -216,7 +208,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages/group returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -226,7 +218,7 @@ class TestEmbedMessagesEndpoint:
"""GET messages with invalid session_type returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 400
@@ -241,7 +233,7 @@ class TestEmbedResetEndpoint:
"""POST reset/person resets session."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
headers={'Authorization': 'Bearer 1234567890.dummy'},
)
assert response.status_code == 200
@@ -252,8 +244,7 @@ class TestEmbedResetEndpoint:
async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@@ -269,7 +260,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}
json={'message_id': 'msg-123', 'feedback_type': 1},
)
assert response.status_code == 200
@@ -283,7 +274,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}
json={'message_id': 'msg-123', 'feedback_type': 2},
)
assert response.status_code == 200
@@ -294,7 +285,7 @@ class TestEmbedFeedbackEndpoint:
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}
json={'message_id': 'msg-123', 'feedback_type': 99},
)
assert response.status_code == 400

View File

@@ -49,6 +49,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,33 +75,35 @@ def fake_knowledge_app():
# Knowledge service
app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
{
app.knowledge_service.get_knowledge_bases = AsyncMock(
return_value=[
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
]
)
app.knowledge_service.get_knowledge_base = AsyncMock(
return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
)
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
])
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}]
)
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}])
# RAG manager
app.rag_mgr = Mock()
@@ -124,8 +129,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -140,7 +144,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'},
)
assert response.status_code == 200
@@ -152,8 +156,7 @@ class TestKnowledgeBaseEndpoints:
async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -167,7 +170,7 @@ class TestKnowledgeBaseEndpoints:
response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}
json={'name': 'Updated KB'},
)
assert response.status_code == 200
@@ -178,8 +181,7 @@ class TestKnowledgeBaseEndpoints:
async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -193,8 +195,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -208,7 +209,7 @@ class TestKnowledgeBaseFilesEndpoints:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'},
)
assert response.status_code == 200
@@ -220,8 +221,7 @@ class TestKnowledgeBaseFilesEndpoints:
async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}},
)
assert response.status_code == 200
@@ -249,9 +249,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={}
)
assert response.status_code == 400

View File

@@ -46,6 +46,7 @@ def mock_circular_import_chain():
clear=clear,
):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield
@@ -54,10 +55,12 @@ def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock()
@@ -67,40 +70,43 @@ def fake_monitoring_app():
# Monitoring service
app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
})
app.monitoring_service.get_messages = AsyncMock(return_value=(
[{'id': 'msg-1', 'content': 'test'}], 100
))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
[{'id': 'llm-1'}], 50
))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
[{'id': 'emb-1'}], 10
))
app.monitoring_service.get_sessions = AsyncMock(return_value=(
[{'session_id': 'sess-1'}], 20
))
app.monitoring_service.get_errors = AsyncMock(return_value=(
[{'id': 'err-1'}], 2
))
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'found': True,
'session_id': 'sess-1',
})
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_overview_metrics = AsyncMock(
return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
}
)
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10))
app.monitoring_service.get_traces = AsyncMock(return_value=([{'trace_id': 'trace-1'}], 1))
app.monitoring_service.get_trace_details = AsyncMock(
return_value={
'found': True,
'trace_id': 'trace-1',
'trace': {'trace_id': 'trace-1'},
'spans': [],
}
)
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20))
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2))
app.monitoring_service.get_session_analysis = AsyncMock(
return_value={
'found': True,
'session_id': 'sess-1',
}
)
app.monitoring_service.get_message_details = AsyncMock(
return_value={
'found': True,
'message_id': 'msg-1',
}
)
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
@@ -130,8 +136,7 @@ class TestMonitoringOverviewEndpoint:
async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get(
'/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -147,8 +152,7 @@ class TestMonitoringMessagesEndpoint:
async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -165,8 +169,7 @@ class TestMonitoringLLMCallsEndpoint:
async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -180,8 +183,7 @@ class TestMonitoringEmbeddingCallsEndpoint:
async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -195,8 +197,7 @@ class TestMonitoringSessionsEndpoint:
async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -210,8 +211,7 @@ class TestMonitoringErrorsEndpoint:
async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors."""
response = await quart_test_client.get(
'/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -225,13 +225,13 @@ class TestMonitoringAllDataEndpoint:
async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get(
'/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert 'overview' in data['data']
assert 'traces' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
@@ -242,8 +242,7 @@ class TestMonitoringDetailsEndpoints:
async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -252,8 +251,16 @@ class TestMonitoringDetailsEndpoints:
async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_trace_details(self, quart_test_client):
"""GET /api/v1/monitoring/traces/{id}."""
response = await quart_test_client.get(
'/api/v1/monitoring/traces/trace-1', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -267,8 +274,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -277,8 +283,7 @@ class TestMonitoringFeedbackEndpoints:
async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -292,8 +297,7 @@ class TestMonitoringExportEndpoint:
async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -303,8 +307,7 @@ class TestMonitoringExportEndpoint:
async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -313,8 +316,7 @@ class TestMonitoringExportEndpoint:
async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -323,8 +325,7 @@ class TestMonitoringExportEndpoint:
async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200

View File

@@ -49,6 +49,7 @@ def mock_circular_import_chain():
):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield
@@ -57,10 +58,12 @@ def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# Auth services
app.user_service = Mock()
@@ -72,27 +75,23 @@ def fake_provider_app():
# Provider service
app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock(return_value=[
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
])
app.provider_service.get_provider = AsyncMock(return_value={
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
})
app.provider_service.get_providers = AsyncMock(
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}]
)
app.provider_service.get_provider = AsyncMock(
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
)
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
})
app.provider_service.get_provider_model_counts = AsyncMock(
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0}
)
# LLM model service
app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}])
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock()
@@ -133,8 +132,7 @@ class TestProviderEndpoints:
async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -157,8 +155,7 @@ class TestProviderEndpoints:
async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -177,7 +174,7 @@ class TestProviderEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}
json={'name': 'New Provider', 'requester': 'chatcmpl'},
)
assert response.status_code == 200
@@ -194,7 +191,7 @@ class TestProviderEndpoints:
response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}
json={'name': 'Updated Provider'},
)
assert response.status_code == 200
@@ -205,8 +202,7 @@ class TestProviderEndpoints:
async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -215,8 +211,7 @@ class TestProviderEndpoints:
async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -237,8 +232,7 @@ class TestModelEndpoints:
async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -250,8 +244,7 @@ class TestModelEndpoints:
async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -264,7 +257,7 @@ class TestModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -276,8 +269,7 @@ class TestModelEndpoints:
async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -291,8 +283,7 @@ class TestEmbeddingModelEndpoints:
async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -306,7 +297,7 @@ class TestEmbeddingModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200
@@ -323,8 +314,7 @@ class TestRerankModelEndpoints:
async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@@ -338,7 +328,7 @@ class TestRerankModelEndpoints:
response = await quart_test_client.post(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'},
)
assert response.status_code == 200

View File

@@ -20,6 +20,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -69,6 +70,7 @@ def mock_circular_import_chain():
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
@@ -79,12 +81,14 @@ def fake_api_app():
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
app.instance_config.data.update(
{
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
}
)
# API-specific services
app.user_service = Mock()
@@ -118,6 +122,7 @@ def fake_api_app():
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls):
"""
@@ -135,6 +140,7 @@ async def quart_test_client(fake_api_app, http_controller_cls):
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@@ -222,8 +228,7 @@ class TestProtectedEndpoints:
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@@ -254,10 +259,7 @@ class TestInvalidPayload:
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'})
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)

View File

@@ -2,4 +2,4 @@
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""
"""

View File

@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
db_file = tmp_path / 'test_migrations.db'
return f'sqlite+aiosqlite:///{db_file}'
@pytest.fixture
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev.startswith('0006'), f'Expected head to be 0006_*, got {rev}'
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestSQLiteMigrationFreshDatabase:
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_migrations_fresh.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
assert rev is not None, 'Expected a revision on fresh DB'
await fresh_engine.dispose()
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_db_file = tmp_path / 'test_empty_migrations.db'
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
# Verify specific behavior - one of two outcomes is expected
if actual_result is not None:
# Migration succeeded - verify revision exists
assert actual_result is not None, "Revision should exist after successful migration"
assert actual_result is not None, 'Revision should exist after successful migration'
else:
# Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables
assert actual_error is not None, "Error should be captured if migration failed"
assert actual_error is not None, 'Error should be captured if migration failed'
# Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios
acceptable_errors = [
'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors
'CommandError', # Alembic command errors
]
assert error_type in acceptable_errors, (
f"Unexpected error type: {error_type}. "
f"This may indicate a regression in migration behavior. "
f"Error: {actual_error}"
f'Unexpected error type: {error_type}. '
f'This may indicate a regression in migration behavior. '
f'Error: {actual_error}'
)
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'

View File

@@ -34,14 +34,14 @@ def postgres_url():
"""Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL')
if not url:
pytest.skip("TEST_POSTGRES_URL not set")
pytest.skip('TEST_POSTGRES_URL not set')
return url
@pytest.fixture
async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT')
yield engine
await engine.dispose()
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
async with postgres_engine.begin() as conn:
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
except Exception:
pass
@@ -83,9 +83,7 @@ class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp baseline on existing tables sets correct revision.
@@ -106,9 +104,7 @@ class TestPostgreSQLMigrationBaseline:
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Stamp on empty database (no tables) still sets revision.
@@ -125,9 +121,7 @@ class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Upgrade from baseline to head applies all migrations.
@@ -149,14 +143,12 @@ class TestPostgreSQLMigrationUpgrade:
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration (0005 for current state)
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
assert rev is not None, 'Expected a revision after upgrade'
# Head should be the latest migration.
assert rev.startswith('0006'), f'Expected head to be 0006_*, got {rev}'
@pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version):
"""
Running upgrade to head multiple times is idempotent.
@@ -180,7 +172,7 @@ class TestPostgreSQLMigrationUpgrade:
await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
class TestPostgreSQLMigrationGetCurrent:
@@ -199,7 +191,7 @@ class TestPostgreSQLMigrationGetCurrent:
# No stamp - should return None
rev = await get_alembic_current(postgres_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
assert rev is None, f'Expected None for unstamped DB, got {rev}'
@pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision(
@@ -214,4 +206,4 @@ class TestPostgreSQLMigrationGetCurrent:
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'
assert rev == '0001_baseline'

View File

@@ -2,4 +2,4 @@
Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner.
"""
"""

View File

@@ -26,6 +26,7 @@ pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -103,6 +104,7 @@ def mock_circular_import_chain():
# ============== FAKE RUNNER ==============
class FakeRunner:
"""Minimal fake runner class for pipeline integration tests.
@@ -117,12 +119,13 @@ class FakeRunner:
self.config = config or {}
self._provider = FakeProvider()
# Instance-level configuration set via class attribute
self._response_text = "fake response"
self._response_text = 'fake response'
self._raise_error = None
@classmethod
def returns(cls, text: str):
"""Create a runner class configured to return specific text."""
# We create a subclass with configured response
class ConfiguredRunner(cls):
name = cls.name
@@ -132,11 +135,13 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._response_text = text
return ConfiguredRunner
@classmethod
def raises(cls, error: Exception):
"""Create a runner class configured to raise an error."""
class ConfiguredRunner(cls):
name = cls.name
_response_text = None
@@ -145,6 +150,7 @@ class FakeRunner:
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._raise_error = error
return ConfiguredRunner
async def run(self, query):
@@ -161,6 +167,7 @@ class FakeRunner:
# ============== PIPELINE APP FIXTURE ==============
@pytest.fixture
def pipeline_app():
"""
@@ -187,6 +194,7 @@ def pipeline_app():
def __init__(self, name, messages):
self.name = name
self.messages = messages
def copy(self):
return MockPrompt(self.name, list(self.messages))
@@ -237,14 +245,17 @@ def fake_platform_adapter():
@pytest.fixture
def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner
# ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests."""
return {
@@ -273,6 +284,7 @@ def create_minimal_pipeline_config():
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name):
"""
Helper to handle the coroutine -> async_generator pattern.
@@ -296,6 +308,7 @@ async def collect_processor_results(processor, query, stage_name):
# ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain."""
@@ -337,7 +350,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
# Create query with adapter
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -365,7 +378,7 @@ class TestPreProcessorStage:
adapter, platform = fake_platform_adapter
query = text_query("test message content")
query = text_query('test message content')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -396,11 +409,11 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Set fake runner that returns pong
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
@@ -414,6 +427,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -432,7 +446,7 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -445,6 +459,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -462,13 +477,13 @@ class TestProcessorStage:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Create reply chain
reply_chain = text_chain("plugin response")
reply_chain = text_chain('plugin response')
# Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock()
@@ -479,6 +494,7 @@ class TestProcessorStage:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -502,7 +518,7 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded'))
set_fake_runner(fake_runner)
# Create query with exception handling config
@@ -510,7 +526,7 @@ class TestRunnerExceptionFlow:
config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -523,6 +539,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -541,14 +558,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error'))
set_fake_runner(fake_runner)
# Create query with show-error mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -561,6 +578,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -578,14 +596,14 @@ class TestRunnerExceptionFlow:
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception("Hidden error"))
fake_runner = FakeRunner().raises(Exception('Hidden error'))
set_fake_runner(fake_runner)
# Create query with hide mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide'
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = config
@@ -598,6 +616,7 @@ class TestRunnerExceptionFlow:
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
@@ -623,7 +642,7 @@ class TestSendResponseBackStage:
adapter, platform = fake_platform_adapter
# Create query with response message
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -666,12 +685,12 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Set fake runner
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
set_fake_runner(fake_runner)
# Create query
config = create_minimal_pipeline_config()
query = text_query("ping")
query = text_query('ping')
query.adapter = adapter
query.pipeline_config = config
query.resp_messages = []
@@ -690,7 +709,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -711,6 +730,7 @@ class TestStageChainIntegration:
# Build resp_message_chain from resp_messages
from tests.factories.message import text_chain
for resp_msg in query.resp_messages:
if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content))
@@ -737,7 +757,7 @@ class TestStageChainIntegration:
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query = text_query('hello')
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
@@ -754,7 +774,7 @@ class TestStageChainIntegration:
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
@@ -775,4 +795,4 @@ class TestStageChainIntegration:
assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages
assert len(query.resp_messages) == 0
assert len(query.resp_messages) == 0

View File

@@ -3,4 +3,4 @@ Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q
"""
"""

View File

@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
assert app.instance_config is not None
# Verify default config
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data["command"]["enable"] is True
assert app.instance_config.data['command']['prefix'] == ['/', '!']
assert app.instance_config.data['command']['enable'] is True
@pytest.mark.asyncio
async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response."""
provider = FakeProvider(default_response="test response")
provider = FakeProvider(default_response='test response')
# Create mock model with provider
model = fake_model(provider=provider)
# Create a simple query
query = text_query("hello")
query = text_query('hello')
# Simulate invoke
result = await provider.invoke_llm(
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
)
assert result is not None
assert result.role == "assistant"
assert result.content == "test response"
assert result.role == 'assistant'
assert result.content == 'test response'
@pytest.mark.asyncio
async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong()
model = fake_model(provider=provider)
query = text_query("ping")
query = text_query('ping')
result = await provider.invoke_llm(
query=query,
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
@pytest.mark.asyncio
async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(["Hello", " World"])
provider = FakeProvider().returns_streaming(['Hello', ' World'])
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
chunks = []
# invoke_llm_stream returns an async generator, don't await it
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].content == "Hello"
assert chunks[1].content == " World"
assert chunks[0].content == 'Hello'
assert chunks[1].content == ' World'
assert chunks[1].is_final is True
@pytest.mark.asyncio
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(TimeoutError, match="Provider timeout"):
with pytest.raises(TimeoutError, match='Provider timeout'):
await provider.invoke_llm(
query=query,
model=model,
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
"""Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit()
model = fake_model(provider=provider)
query = text_query("hello")
query = text_query('hello')
with pytest.raises(Exception, match="Rate limit exceeded"):
with pytest.raises(Exception, match='Rate limit exceeded'):
await provider.invoke_llm(
query=query,
model=model,
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments."""
provider = FakeProvider()
model = fake_model(name="gpt-4", provider=provider)
query = text_query("hello")
model = fake_model(name='gpt-4', provider=provider)
query = text_query('hello')
await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "hello"}],
funcs=[{"name": "test_func"}],
extra_args={"temperature": 0.7},
messages=[{'role': 'user', 'content': 'hello'}],
funcs=[{'name': 'test_func'}],
extra_args={'temperature': 0.7},
)
captured = provider.get_captured_requests()
assert len(captured) == 1
assert captured[0]["model"] == "gpt-4"
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]["extra_args"] == {"temperature": 0.7}
assert captured[0]['model'] == 'gpt-4'
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}]
assert captured[0]['funcs'] == [{'name': 'test_func'}]
assert captured[0]['extra_args'] == {'temperature': 0.7}
@pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id="test-bot")
query = text_query("hello")
platform = FakePlatform(bot_account_id='test-bot')
query = text_query('hello')
# Simulate sending reply
from tests.factories.message import text_chain
reply_chain = text_chain("response text")
reply_chain = text_chain('response text')
event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False)
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
# Verify captured
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert outbound[0]["message"] == reply_chain
assert outbound[0]['type'] == 'reply'
assert outbound[0]['message'] == reply_chain
@pytest.mark.asyncio
async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_friend_message(
text="hello bot",
text='hello bot',
sender_id=12345,
nickname="TestUser",
nickname='TestUser',
)
assert event.type == "FriendMessage"
assert event.type == 'FriendMessage'
assert event.sender.id == 12345
assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == "hello bot"
assert event.sender.nickname == 'TestUser'
assert str(event.message_chain) == 'hello bot'
@pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
event = platform.create_group_message(
text="hello everyone",
text='hello everyone',
sender_id=12345,
group_id=99999,
mention_bot=True,
)
assert event.type == "GroupMessage"
assert event.type == 'GroupMessage'
assert event.sender.id == 12345
assert event.group.id == 99999
@@ -220,54 +220,57 @@ class TestFakeMessageFlow:
async def test_query_factories_basic(self):
"""Test basic query factory functions."""
# Text query
q1 = text_query("hello world")
assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == "hello world"
q1 = text_query('hello world')
assert q1.launcher_type.value == 'person'
assert str(q1.message_chain) == 'hello world'
# Group query
from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
assert q2.launcher_type.value == "group"
q2 = group_text_query('hello group', group_id=88888)
assert q2.launcher_type.value == 'group'
assert q2.launcher_id == 88888
# Command query
from tests.factories import command_query
q3 = command_query("help", prefix="/")
assert str(q3.message_chain) == "/help"
q3 = command_query('help', prefix='/')
assert str(q3.message_chain) == '/help'
# Mention query
from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
assert q4.launcher_type.value == "group"
q4 = mention_query('hi', target='test-bot', group_id=77777)
assert q4.launcher_type.value == 'group'
@pytest.mark.asyncio
async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure()
query = text_query("hello")
query = text_query('hello')
from tests.factories.message import text_chain
with pytest.raises(Exception, match="Platform send failure"):
with pytest.raises(Exception, match='Platform send failure'):
await platform.reply_message(
query.message_event,
text_chain("response"),
text_chain('response'),
)
@pytest.mark.asyncio
async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id="bot-123")
platform = FakePlatform(bot_account_id='bot-123')
adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == "bot-123"
assert adapter.bot_account_id == 'bot-123'
assert adapter._fake_platform is platform
# Test reply_message is wired
from tests.factories.message import text_chain
query = text_query("test")
await adapter.reply_message(query.message_event, text_chain("response"))
query = text_query('test')
await adapter.reply_message(query.message_event, text_chain('response'))
# Verify platform captured it
assert len(platform.get_outbound_messages()) == 1
@@ -293,18 +296,18 @@ class TestMessageFlowIntegration:
Note: This does NOT run actual LangBot pipeline stages.
"""
# Setup
platform = FakePlatform(bot_account_id="test-bot")
platform = FakePlatform(bot_account_id='test-bot')
provider = fake_provider_pong()
model = fake_model(provider=provider)
# Create inbound message
query = text_query("ping")
query = text_query('ping')
# Simulate provider processing
response = await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "ping"}],
messages=[{'role': 'user', 'content': 'ping'}],
funcs=[],
extra_args={},
)
@@ -321,16 +324,16 @@ class TestMessageFlowIntegration:
# Verify platform captured outbound
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
assert outbound[0]['type'] == 'reply'
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(["Hello", " there"])
provider = FakeProvider().returns_streaming(['Hello', ' there'])
model = fake_model(provider=provider)
query = text_query("hi")
query = text_query('hi')
chunks = []
async for chunk in provider.invoke_llm_stream(
@@ -344,8 +347,8 @@ class TestMessageFlowIntegration:
# Verify streaming worked
assert len(chunks) == 2
full_content = "".join(c.content for c in chunks)
assert full_content == "Hello there"
full_content = ''.join(c.content for c in chunks)
assert full_content == 'Hello there'
# Verify platform supports streaming
assert await platform.is_stream_output_supported() is True
assert await platform.is_stream_output_supported() is True

View File

@@ -15,22 +15,12 @@ import pathlib
# Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py'
def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
tree = ast.parse(source)
@@ -40,27 +30,26 @@ def test_no_exec_call_in_system_controller():
if isinstance(node, ast.Call):
func = node.func
# Match bare exec() call
if isinstance(func, ast.Name) and func.id == "exec":
if isinstance(func, ast.Name) and func.id == 'exec':
exec_calls.append(node.lineno)
assert len(exec_calls) == 0, (
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().'
)
def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered."""
with open(VULN_FILE, "r") as f:
with open(VULN_FILE, 'r') as f:
source = f.read()
assert "debug/exec" not in source, (
"The /debug/exec route still exists in system.py. "
"This endpoint allows arbitrary code execution and must be removed."
assert 'debug/exec' not in source, (
'The /debug/exec route still exists in system.py. '
'This endpoint allows arbitrary code execution and must be removed.'
)
if __name__ == "__main__":
if __name__ == '__main__':
test_no_exec_call_in_system_controller()
test_no_debug_exec_route()
print("All tests passed!")
print('All tests passed!')

View File

@@ -1 +1 @@
"""Unit tests for LangBot API HTTP service layer."""
"""Unit tests for LangBot API HTTP service layer."""

View File

@@ -13,4 +13,4 @@ Does NOT:
- Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application.
"""
"""

View File

@@ -132,9 +132,7 @@ class TestApiKeyServiceCreateApiKey:
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}]
assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key'

View File

@@ -303,13 +303,7 @@ class TestBotServiceCreateBot:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
@@ -318,9 +312,7 @@ class TestBotServiceCreateBot:
bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'})
service = BotService(ap)
@@ -352,6 +344,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -362,9 +355,7 @@ class TestBotServiceCreateBot:
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'})
service = BotService(ap)
@@ -397,6 +388,7 @@ class TestBotServiceCreateBot:
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -492,6 +484,7 @@ class TestBotServiceUpdateBot:
pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -582,10 +575,9 @@ class TestBotServiceListEventLogs:
# Mock runtime bot with logger
runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock(return_value=(
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
5
))
runtime_bot.logger.get_logs = AsyncMock(
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5)
)
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
@@ -646,11 +638,7 @@ class TestBotServiceSendMessage:
service = BotService(ap)
# Execute with valid message chain format
message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]}
# Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:

View File

@@ -6,6 +6,7 @@ Tests cover:
- Knowledge engine discovery
- File operations
"""
from __future__ import annotations
import pytest
@@ -52,9 +53,7 @@ class TestGetKnowledgeBases:
"""Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}])
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
@@ -83,9 +82,7 @@ class TestGetKnowledgeBase:
"""Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'})
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1')
@@ -153,9 +150,7 @@ class TestCreateKnowledgeBase:
service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'})
# Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
@@ -170,20 +165,21 @@ class TestUpdateKnowledgeBase:
"""Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'})
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields
await service.update_knowledge_base('kb1', {
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
})
await service.update_knowledge_base(
'kb1',
{
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
},
)
# Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args
@@ -288,9 +284,7 @@ class TestListKnowledgeEngines:
"""Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
@@ -386,12 +380,10 @@ class TestGetEngineSchemas:
"""Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error'))
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert result == {}
mock_app.logger.warning.assert_called_once()
mock_app.logger.warning.assert_called_once()

View File

@@ -174,9 +174,7 @@ class TestMaintenanceServiceGetStorageAnalysis:
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
@@ -292,12 +290,8 @@ class TestMaintenanceServiceGetStorageAnalysis:
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[
{'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}])
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}])
# Execute
result = await service.get_storage_analysis()
@@ -367,6 +361,7 @@ class TestMaintenanceServiceBinaryStorageStats:
size_result = _create_mock_result(scalar_value=5000)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +391,7 @@ class TestMaintenanceServiceBinaryStorageStats:
count_result = _create_mock_result(scalar_value=5)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -821,4 +817,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included
assert 'path' in result[0]
assert 'path' in result[0]

View File

@@ -186,13 +186,7 @@ class TestMCPServiceCreateMCPServer:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}}
ap.plugin_connector = SimpleNamespace()
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
@@ -252,6 +246,7 @@ class TestMCPServiceCreateMCPServer:
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -361,6 +356,7 @@ class TestMCPServiceUpdateMCPServer:
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -394,6 +390,7 @@ class TestMCPServiceUpdateMCPServer:
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -432,6 +429,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock for: first select -> update -> second select (for updated server)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -465,6 +463,7 @@ class TestMCPServiceUpdateMCPServer:
# Mock execute for select and update
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -499,6 +498,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Server to Delete')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -530,6 +530,7 @@ class TestMCPServiceDeleteMCPServer:
server = _create_mock_mcp_server(name='Not in Sessions')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -559,6 +560,7 @@ class TestMCPServiceDeleteMCPServer:
# No server found
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -596,9 +598,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=123)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123))
service = MCPService(ap)
@@ -634,9 +634,7 @@ class TestMCPServiceTestMCPServer:
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=456)
)
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456))
service = MCPService(ap)
@@ -645,4 +643,4 @@ class TestMCPServiceTestMCPServer:
# Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456
assert task_id == 456

View File

@@ -167,6 +167,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([])
call_count = 0
async def mock_execute(query):
return mock_result if call_count == 0 else mock_provider_result
@@ -200,6 +201,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -239,6 +241,7 @@ class TestLLMModelsServiceGetLLMModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -279,6 +282,7 @@ class TestLLMModelsServiceGetLLMModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -337,9 +341,7 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'model-1', 'name': 'Model 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'})
service = LLMModelsService(ap)
@@ -371,12 +373,14 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
})
model_uuid = await service.create_llm_model(
{
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -400,13 +404,16 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}, preserve_uuid=True)
model_uuid = await service.create_llm_model(
{
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
},
preserve_uuid=True,
)
# Verify
assert model_uuid == 'preserved-uuid'
@@ -459,12 +466,14 @@ class TestLLMModelsServiceCreateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_llm_model({
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
})
await service.create_llm_model(
{
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
}
)
async def test_create_llm_model_with_provider_data(self):
"""Creates provider when provider data provided."""
@@ -490,16 +499,18 @@ class TestLLMModelsServiceCreateLLMModel:
service = LLMModelsService(ap)
# Execute - with provider data (no UUID)
result_uuid = await service.create_llm_model({
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
})
result_uuid = await service.create_llm_model(
{
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
}
)
# Verify - provider_service was called and UUID generated
ap.provider_service.find_or_create_provider.assert_called_once()
@@ -525,11 +536,14 @@ class TestLLMModelsServiceUpdateLLMModel:
service = LLMModelsService(ap)
# Execute
await service.update_llm_model('existing-uuid', {
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
})
await service.update_llm_model(
'existing-uuid',
{
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
},
)
# Verify - remove and load called
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
@@ -549,10 +563,13 @@ class TestLLMModelsServiceUpdateLLMModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.update_llm_model('model-uuid', {
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
})
await service.update_llm_model(
'model-uuid',
{
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
},
)
async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args."""
@@ -618,9 +635,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'})
service = EmbeddingModelsService(ap)
@@ -643,6 +658,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -683,6 +699,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -742,11 +759,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
service = EmbeddingModelsService(ap)
# Execute
model_uuid = await service.create_embedding_model({
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_embedding_model(
{
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -767,11 +786,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_embedding_model({
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_embedding_model(
{
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
@@ -829,6 +850,7 @@ class TestRerankModelsServiceGetRerankModels:
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -869,6 +891,7 @@ class TestRerankModelsServiceGetRerankModel:
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -928,11 +951,13 @@ class TestRerankModelsServiceCreateRerankModel:
service = RerankModelsService(ap)
# Execute
model_uuid = await service.create_rerank_model({
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
model_uuid = await service.create_rerank_model(
{
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
}
)
# Verify
assert model_uuid is not None
@@ -952,11 +977,13 @@ class TestRerankModelsServiceCreateRerankModel:
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_rerank_model({
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
await service.create_rerank_model(
{
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
}
)
class TestRerankModelsServiceDeleteRerankModel:
@@ -995,9 +1022,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'})
service = EmbeddingModelsService(ap)
@@ -1022,9 +1047,7 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'})
service = RerankModelsService(ap)
@@ -1042,14 +1065,10 @@ class TestValidateProviderSupports:
def _make_ap(requester_name: str, support_type):
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
manifest = SimpleNamespace(spec={'support_type': support_type})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester=requester_name)
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest
if name == requester_name
else None,
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None,
)
return SimpleNamespace(model_mgr=model_mgr)
@@ -1066,9 +1085,7 @@ class TestValidateProviderSupports:
async def test_allows_when_support_type_missing(self):
# Manifest without support_type must not block (backward compatible)
manifest = SimpleNamespace(spec={})
runtime_provider = SimpleNamespace(
provider_entity=SimpleNamespace(requester='legacy')
)
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy'))
model_mgr = SimpleNamespace(
provider_dict={'p1': runtime_provider},
get_available_requester_manifest_by_name=lambda name: manifest,

View File

@@ -215,13 +215,7 @@ class TestPipelineServiceCreatePipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -229,9 +223,7 @@ class TestPipelineServiceCreatePipeline:
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'})
service = PipelineService(ap)
@@ -258,14 +250,14 @@ class TestPipelineServiceCreatePipeline:
# Mock persistence for insert
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
# Mock the file read for default config - patch at the utils module level
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
# Verify
@@ -286,7 +278,9 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
service.get_pipeline = AsyncMock(
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
)
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
@@ -296,7 +290,9 @@ class TestPipelineServiceCreatePipeline:
# Mock the file read
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
# Verify - execute was called
@@ -316,10 +312,12 @@ class TestPipelineServiceCreatePipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
})
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
}
)
insert_params = []
@@ -339,7 +337,9 @@ class TestPipelineServiceCreatePipeline:
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
with patch(
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
):
await service.create_pipeline({'name': 'New Pipeline'})
assert len(insert_params) == 1
@@ -353,6 +353,7 @@ class TestPipelineServiceCreatePipeline:
class _MockResultWithBots:
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
def __init__(self, bots_list):
self._bots_list = bots_list
@@ -428,6 +429,7 @@ class TestPipelineServiceUpdatePipeline:
# 1. UPDATE (line 125) - returns Mock (no result needed)
# 2. SELECT bots (line 136) - returns bot_result with .all()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -528,13 +530,7 @@ class TestPipelineServiceCopyPipeline:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
@@ -542,10 +538,12 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
# Mock get_pipelines to return 2 pipelines
service.get_pipelines = AsyncMock(return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
])
service.get_pipelines = AsyncMock(
return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
]
)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'):
@@ -642,9 +640,7 @@ class TestPipelineServiceCopyPipeline:
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'copy-uuid', 'is_default': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False})
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
@@ -681,11 +677,10 @@ class TestPipelineServiceUpdatePipelineExtensions:
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
)
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []})
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -700,7 +695,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -711,7 +706,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
},
}
)
@@ -738,6 +733,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
original_pipeline = _create_mock_pipeline()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -752,7 +748,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
'extensions_preferences': {
'enable_all_mcp_servers': False,
'mcp_servers': ['mcp-server-1'],
}
},
}
)
@@ -794,6 +790,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1

View File

@@ -245,12 +245,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
provider_uuid = await service.create_provider({
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
provider_uuid = await service.create_provider(
{
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - UUID is generated
assert provider_uuid is not None
@@ -274,12 +276,14 @@ class TestModelProviderServiceCreateProvider:
service = ModelProviderService(ap)
# Execute
result_uuid = await service.create_provider({
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
result_uuid = await service.create_provider(
{
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Verify - provider added to runtime dict and UUID generated
ap.model_mgr.load_provider.assert_called_once()
@@ -302,10 +306,13 @@ class TestModelProviderServiceUpdateProvider:
service = ModelProviderService(ap)
# Execute
await service.update_provider('existing-uuid', {
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
})
await service.update_provider(
'existing-uuid',
{
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
},
)
# Verify - reload called
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
@@ -364,6 +371,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=None)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -396,6 +404,7 @@ class TestModelProviderServiceDeleteProvider:
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -454,6 +463,7 @@ class TestModelProviderServiceGetProviderModelCounts:
rerank_result.scalar = Mock(return_value=1)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -637,9 +647,7 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
await service.update_space_model_provider_api_keys('space-api-key')
# Verify - update and reload called for Space provider UUID
ap.model_mgr.reload_provider.assert_called_once_with(
'00000000-0000-0000-0000-000000000000'
)
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000')
class TestModelProviderServiceScanProviderModels:
@@ -795,9 +803,7 @@ class TestModelProviderServiceScanProviderModels:
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
runtime_provider.requester.scan_models = AsyncMock(
side_effect=NotImplementedError('scan not supported')
)
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported'))
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
service = ModelProviderService(ap)
@@ -848,9 +854,7 @@ class TestModelProviderServiceScanProviderModels:
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing LLM model
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
return_value=[{'name': 'Existing Model'}]
)
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}])
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
@@ -863,4 +867,4 @@ class TestModelProviderServiceScanProviderModels:
assert existing_model['already_added'] is True
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
assert new_model['already_added'] is False
assert new_model['already_added'] is False

View File

@@ -393,14 +393,16 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -429,10 +431,12 @@ class TestSpaceServiceRefreshToken:
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 1,
'msg': 'Invalid refresh token',
})
mock_response.json = AsyncMock(
return_value={
'code': 1,
'msg': 'Invalid refresh token',
}
)
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
@@ -489,14 +493,16 @@ class TestSpaceServiceExchangeOAuthCode:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -555,13 +561,15 @@ class TestSpaceServiceGetUserInfoRaw:
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -669,27 +677,29 @@ class TestSpaceServiceGetModels:
# Mock HTTP response with proper model data matching SpaceModel schema
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
mock_response.json = AsyncMock(
return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
},
}
})
)
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
@@ -775,4 +785,4 @@ class TestSpaceServiceCreditsCache:
# Verify - cache updated
assert result == 500
assert 'test@example.com' in service._credits_cache
assert service._credits_cache['test@example.com'][0] == 500
assert service._credits_cache['test@example.com'][0] == 500

View File

@@ -495,6 +495,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -565,6 +566,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
@@ -605,4 +607,4 @@ class TestUserServiceCreateUserLock:
# Verify lock exists
assert hasattr(service, '_create_user_lock')
assert service._create_user_lock is not None
assert service._create_user_lock is not None

View File

@@ -132,6 +132,7 @@ class TestWebhookServiceCreateWebhook:
# execute_async returns different results
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -181,6 +182,7 @@ class TestWebhookServiceCreateWebhook:
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -217,6 +219,7 @@ class TestWebhookServiceCreateWebhook:
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
@@ -225,9 +228,7 @@ class TestWebhookServiceCreateWebhook:
return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'id': 1, 'enabled': False}
)
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False})
service = WebhookService(ap)
@@ -503,4 +504,4 @@ class TestWebhookServiceGetEnabledWebhooks:
result = await service.get_enabled_webhooks()
# Verify - should be empty (SQL would filter disabled)
assert result == []
assert result == []

View File

@@ -407,7 +407,9 @@ def test_box_service_forced_template_ignores_pipeline_config():
launcher_type='person',
launcher_id='test_user',
sender_id='test_user',
pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
pipeline_config={
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
},
)
assert service.resolve_box_session_id(query) == 'global'
@@ -1527,9 +1529,7 @@ class TestBuildSkillExtraMounts:
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
]
# No skill is dropped, so no "missing" warning should be logged.
assert not any(
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
)
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list)
def test_skips_skill_with_empty_package_root(self):
logger = Mock()

View File

@@ -54,9 +54,7 @@ def test_classify_python_workspace_detects_package_and_requirements():
def test_wrap_python_command_with_env_contains_bootstrap_and_command():
command = wrap_python_command_with_env('python script.py')
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
assert 'kill -0 "$_LB_LOCK_OWNER"' in command
assert 'python -m venv "$_LB_VENV_DIR"' in command
assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command
assert command.rstrip().endswith('python script.py')

View File

@@ -1 +1 @@
# Unit tests for command module
# Unit tests for command module

View File

@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
# Should yield CommandNotFoundError (no such command registered)
assert len(results) == 1
assert results[0].error is not None
assert results[0].error is not None

View File

@@ -197,6 +197,7 @@ class TestCommandOperatorBase:
op = TestOperator(None)
# Should not raise
import asyncio
asyncio.get_event_loop().run_until_complete(op.initialize())
def test_execute_is_abstract(self):
@@ -299,4 +300,4 @@ class TestMultipleOperators:
yield None
assert AdminOperator.lowest_privilege == 2
assert SubOperator.lowest_privilege == 1
assert SubOperator.lowest_privilege == 1

View File

@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
@pytest.mark.asyncio
async def test_valid_yaml_loads(self, tmp_path):
"""Valid YAML config should load correctly."""
config_file = tmp_path / "test_config.yaml"
config_file = tmp_path / 'test_config.yaml'
# Write valid YAML
config_file.write_text("""
@@ -51,7 +51,7 @@ settings:
@pytest.mark.asyncio
async def test_invalid_yaml_raises_error(self, tmp_path):
"""Invalid YAML should raise clear error."""
config_file = tmp_path / "invalid.yaml"
config_file = tmp_path / 'invalid.yaml'
# Write invalid YAML (unclosed bracket)
config_file.write_text("""
@@ -67,13 +67,13 @@ settings:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await yaml_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_config_creates_from_template(self, tmp_path):
"""Missing config file should be created from template."""
config_file = tmp_path / "new_config.yaml"
config_file = tmp_path / 'new_config.yaml'
# File doesn't exist yet
assert not config_file.exists()
@@ -92,7 +92,7 @@ settings:
@pytest.mark.asyncio
async def test_template_completion(self, tmp_path):
"""Config should be completed with template defaults."""
config_file = tmp_path / "partial.yaml"
config_file = tmp_path / 'partial.yaml'
# Write partial config missing some template keys
config_file.write_text("""
@@ -115,7 +115,7 @@ name: custom_name
@pytest.mark.asyncio
async def test_yaml_save(self, tmp_path):
"""YAML config can be saved."""
config_file = tmp_path / "save_test.yaml"
config_file = tmp_path / 'save_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -131,7 +131,7 @@ name: custom_name
def test_yaml_save_sync(self, tmp_path):
"""YAML config can be saved synchronously."""
config_file = tmp_path / "sync_save.yaml"
config_file = tmp_path / 'sync_save.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -151,14 +151,18 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_valid_json_loads(self, tmp_path):
"""Valid JSON config should load correctly."""
config_file = tmp_path / "test_config.json"
config_file = tmp_path / 'test_config.json'
# Write valid JSON
config_file.write_text(json.dumps({
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}))
config_file.write_text(
json.dumps(
{
'name': 'json_app',
'version': '1.0',
'settings': {'debug': True, 'port': 8080},
}
)
)
json_file = JSONConfigFile(
str(config_file),
@@ -174,7 +178,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_invalid_json_raises_error(self, tmp_path):
"""Invalid JSON should raise clear error."""
config_file = tmp_path / "invalid.json"
config_file = tmp_path / 'invalid.json'
# Write invalid JSON (missing closing brace)
config_file.write_text('{"name": "test", "unclosed": ')
@@ -184,13 +188,13 @@ class TestJSONConfigFile:
template_data={'name': 'default'},
)
with pytest.raises(Exception, match="Syntax error"):
with pytest.raises(Exception, match='Syntax error'):
await json_file.load(completion=False)
@pytest.mark.asyncio
async def test_missing_json_creates_from_template(self, tmp_path):
"""Missing JSON file should be created from template."""
config_file = tmp_path / "new_config.json"
config_file = tmp_path / 'new_config.json'
json_file = JSONConfigFile(
str(config_file),
@@ -205,7 +209,7 @@ class TestJSONConfigFile:
@pytest.mark.asyncio
async def test_json_save(self, tmp_path):
"""JSON config can be saved."""
config_file = tmp_path / "save_test.json"
config_file = tmp_path / 'save_test.json'
json_file = JSONConfigFile(
str(config_file),
@@ -226,7 +230,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_load(self, tmp_path):
"""ConfigManager loads config correctly."""
config_file = tmp_path / "manager_test.yaml"
config_file = tmp_path / 'manager_test.yaml'
config_file.write_text('name: managed_app\nversion: "1.0"\n')
yaml_file = YAMLConfigFile(
@@ -243,7 +247,7 @@ class TestConfigManager:
@pytest.mark.asyncio
async def test_config_manager_dump(self, tmp_path):
"""ConfigManager can dump config."""
config_file = tmp_path / "dump_test.yaml"
config_file = tmp_path / 'dump_test.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -260,7 +264,7 @@ class TestConfigManager:
def test_config_manager_dump_sync(self, tmp_path):
"""ConfigManager can dump config synchronously."""
config_file = tmp_path / "sync_dump.yaml"
config_file = tmp_path / 'sync_dump.yaml'
yaml_file = YAMLConfigFile(
str(config_file),
@@ -280,7 +284,7 @@ class TestConfigExists:
def test_yaml_exists_true(self, tmp_path):
"""exists() returns True for existing file."""
config_file = tmp_path / "exists.yaml"
config_file = tmp_path / 'exists.yaml'
config_file.write_text('name: test')
yaml_file = YAMLConfigFile(str(config_file), template_data={})
@@ -288,14 +292,14 @@ class TestConfigExists:
def test_yaml_exists_false(self, tmp_path):
"""exists() returns False for missing file."""
config_file = tmp_path / "missing.yaml"
config_file = tmp_path / 'missing.yaml'
yaml_file = YAMLConfigFile(str(config_file), template_data={})
assert yaml_file.exists() is False
def test_json_exists_true(self, tmp_path):
"""exists() returns True for existing JSON file."""
config_file = tmp_path / "exists.json"
config_file = tmp_path / 'exists.json'
config_file.write_text('{}')
json_file = JSONConfigFile(str(config_file), template_data={})
@@ -303,7 +307,7 @@ class TestConfigExists:
def test_json_exists_false(self, tmp_path):
"""exists() returns False for missing JSON file."""
config_file = tmp_path / "missing.json"
config_file = tmp_path / 'missing.json'
json_file = JSONConfigFile(str(config_file), template_data={})
assert json_file.exists() is False
assert json_file.exists() is False

View File

@@ -1 +1 @@
"""Core module unit tests."""
"""Core module unit tests."""

View File

@@ -4,6 +4,7 @@ Tests cover:
- _get_positive_int_config() validation
- _get_positive_float_config() validation
"""
from __future__ import annotations
from unittest.mock import Mock
@@ -188,4 +189,4 @@ class TestGetPositiveFloatConfig:
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
assert result == 1.5
mock_logger.warning.assert_called_once()
mock_logger.warning.assert_called_once()

View File

@@ -27,6 +27,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert result == []
@@ -46,6 +47,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
assert 'requests' in result
@@ -61,6 +63,7 @@ class TestCheckDeps:
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
import asyncio
result = asyncio.get_event_loop().run_until_complete(check_deps())
# Should include all required_deps keys
@@ -107,6 +110,7 @@ class TestPrecheckPluginDeps:
with patch('os.path.exists', return_value=False):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_not_called()
@@ -129,6 +133,7 @@ class TestPrecheckPluginDeps:
with patch('os.listdir', side_effect=mock_listdir):
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
import asyncio
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])

View File

@@ -7,6 +7,7 @@ Tests cover:
- Dict type skipping
- Missing key creation
"""
from __future__ import annotations
import os
@@ -248,15 +249,8 @@ class TestApplyEnvOverridesToConfig:
"""Test multiple env vars applied in order."""
load_config = get_load_config_module()
cfg = {
'system': {'name': 'default', 'enable': True},
'concurrency': {'pipeline': 5}
}
env = {
'SYSTEM__NAME': 'custom',
'SYSTEM__ENABLE': 'false',
'CONCURRENCY__PIPELINE': '10'
}
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}}
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'}
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
@@ -287,4 +281,4 @@ class TestApplyEnvOverridesToConfig:
with patch.dict(os.environ, env, clear=True):
result = load_config._apply_env_overrides_to_config(cfg)
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'

View File

@@ -175,4 +175,4 @@ class TestPreregisteredStages:
pass
for key in preregistered_stages:
assert isinstance(key, str)
assert isinstance(key, str)

View File

@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import_isolation to break circular import chains.
"""
from __future__ import annotations
import pytest
@@ -19,15 +20,17 @@ from typing import Generator
class MockLifecycleControlScopeEnum:
"""Mock enum value for LifecycleControlScope with .value attribute."""
def __init__(self, value: str):
self.value = value
def __repr__(self):
return f"LifecycleControlScope.{self.value.upper()}"
return f'LifecycleControlScope.{self.value.upper()}'
class MockLifecycleControlScope:
"""Mock enum for LifecycleControlScope."""
APPLICATION = MockLifecycleControlScopeEnum('application')
PLATFORM = MockLifecycleControlScopeEnum('platform')
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
@@ -40,17 +43,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
# Mock modules that cause circular imports
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
mock_app = MagicMock()
mock_importutil = MagicMock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
mock_http_controller = MagicMock()
mock_rag_mgr = MagicMock()
mocks = {
'langbot.pkg.core.entities': mock_entities,
'langbot.pkg.core.app': mock_app,
@@ -58,26 +61,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
'langbot.pkg.utils.importutil': mock_importutil,
}
# Save original state
saved = {}
for name in mocks:
if name in sys.modules:
saved[name] = sys.modules[name]
# Clear taskmgr to force re-import
taskmgr_name = 'langbot.pkg.core.taskmgr'
if taskmgr_name in sys.modules:
saved[taskmgr_name] = sys.modules[taskmgr_name]
try:
# Apply mocks
for name, module in mocks.items():
sys.modules[name] = module
# Clear taskmgr
sys.modules.pop(taskmgr_name, None)
yield
finally:
# Restore
@@ -86,7 +89,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
sys.modules[name] = saved[name]
else:
sys.modules.pop(name, None)
if taskmgr_name in saved:
sys.modules[taskmgr_name] = saved[taskmgr_name]
else:
@@ -97,6 +100,7 @@ def get_taskmgr_classes():
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
return TaskContext, TaskWrapper, AsyncTaskManager
@@ -194,9 +198,10 @@ class TestTaskContext:
"""Test TaskContext.placeholder() returns singleton."""
with isolated_taskmgr_import():
from langbot.pkg.core.taskmgr import TaskContext
# Reset global placeholder
import langbot.pkg.core.taskmgr as taskmgr_module
taskmgr_module.placeholder_context = None
ctx1 = TaskContext.placeholder()
@@ -269,7 +274,8 @@ class TestTaskWrapper:
return 'result'
wrapper = TaskWrapper(
mock_app, immediate_coro(),
mock_app,
immediate_coro(),
name='test_task',
label='Test Task',
)
@@ -414,7 +420,7 @@ class TestAsyncTaskManager:
async def test_cancel_by_scope(self):
"""Test cancel_by_scope cancels matching tasks."""
_, _, AsyncTaskManager = get_taskmgr_classes()
mock_app = create_mock_app()
manager = AsyncTaskManager(mock_app)
@@ -422,16 +428,10 @@ class TestAsyncTaskManager:
await asyncio.sleep(10)
# Create task with APPLICATION scope
w1 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.APPLICATION]
)
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION])
# Create task with different scope
w2 = manager.create_task(
long_coro(),
scopes=[MockLifecycleControlScope.PIPELINE]
)
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE])
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)

View File

@@ -15,68 +15,68 @@ class TestI18nString:
def test_create_with_english_only(self):
"""Create I18nString with only English."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
assert i18n.en_US == "Hello"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans is None
def test_create_with_multiple_languages(self):
"""Create I18nString with multiple languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
)
assert i18n.en_US == "Hello"
assert i18n.zh_Hans == "你好"
assert i18n.zh_Hant == "你好"
assert i18n.ja_JP == "こんにちは"
assert i18n.en_US == 'Hello'
assert i18n.zh_Hans == '你好'
assert i18n.zh_Hant == '你好'
assert i18n.ja_JP == 'こんにちは'
def test_to_dict_with_english_only(self):
"""to_dict returns only non-None fields."""
i18n = I18nString(en_US="Hello")
i18n = I18nString(en_US='Hello')
result = i18n.to_dict()
assert result == {"en_US": "Hello"}
assert result == {'en_US': 'Hello'}
def test_to_dict_with_multiple_languages(self):
"""to_dict returns all non-None fields."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
en_US='Hello',
zh_Hans='你好',
)
result = i18n.to_dict()
assert result == {"en_US": "Hello", "zh_Hans": "你好"}
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'}
def test_to_dict_excludes_none(self):
"""to_dict excludes None values."""
i18n = I18nString(
en_US="Hello",
en_US='Hello',
zh_Hans=None,
ja_JP="こんにちは",
ja_JP='こんにちは',
)
result = i18n.to_dict()
assert "zh_Hans" not in result
assert "en_US" in result
assert "ja_JP" in result
assert 'zh_Hans' not in result
assert 'en_US' in result
assert 'ja_JP' in result
def test_to_dict_all_languages(self):
"""to_dict with all supported languages."""
i18n = I18nString(
en_US="Hello",
zh_Hans="你好",
zh_Hant="你好",
ja_JP="こんにちは",
th_TH="สวัสดี",
vi_VN="Xin chào",
es_ES="Hola",
en_US='Hello',
zh_Hans='你好',
zh_Hant='你好',
ja_JP='こんにちは',
th_TH='สวัสดี',
vi_VN='Xin chào',
es_ES='Hola',
)
result = i18n.to_dict()
@@ -92,30 +92,30 @@ class TestMetadata:
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test Component"),
name='test-component',
label=I18nString(en_US='Test Component'),
)
assert metadata.name == "test-component"
assert metadata.label.en_US == "Test Component"
assert metadata.name == 'test-component'
assert metadata.label.en_US == 'Test Component'
def test_create_with_all_fields(self):
"""Create Metadata with all optional fields."""
from langbot.pkg.discover.engine import I18nString
metadata = Metadata(
name="test-component",
label=I18nString(en_US="Test"),
description=I18nString(en_US="A test component"),
version="1.0.0",
icon="test-icon",
author="Test Author",
repository="https://github.com/test/repo",
name='test-component',
label=I18nString(en_US='Test'),
description=I18nString(en_US='A test component'),
version='1.0.0',
icon='test-icon',
author='Test Author',
repository='https://github.com/test/repo',
)
assert metadata.version == "1.0.0"
assert metadata.icon == "test-icon"
assert metadata.author == "Test Author"
assert metadata.version == '1.0.0'
assert metadata.icon == 'test-icon'
assert metadata.author == 'Test Author'
class TestComponentManifest:

View File

@@ -7,6 +7,7 @@ Tests cover:
Note: Uses import isolation to break circular import chains.
"""
from __future__ import annotations
import sys
@@ -86,6 +87,7 @@ def get_database_module():
"""Get database module with import isolation."""
with isolated_database_import():
from langbot.pkg.persistence import database
return database
@@ -198,4 +200,4 @@ class TestManagerClassDecorator:
# Create instance to test method (with mock app)
mock_app = Mock()
instance = ManagerWithMethods(mock_app)
assert instance.custom_method() == 'test_value'
assert instance.custom_method() == 'test_value'

View File

@@ -4,6 +4,7 @@ Tests cover:
- execute_async() with mock database
- get_db_engine() with mock database manager
"""
from __future__ import annotations
import pytest
@@ -85,7 +86,7 @@ class TestExecuteAsync:
mock_db.get_engine = Mock(return_value=mock_engine)
mgr.db = mock_db
result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
result = await mgr.execute_async(sqlalchemy.text('SELECT 1'))
# Verify result is the same object returned by execute
assert result is mock_result
@@ -152,4 +153,4 @@ class TestSerializeModelEdgeCases:
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
# Result should be empty dict when all columns masked
assert result == {}
assert result == {}

View File

@@ -5,6 +5,7 @@ Tests cover:
- datetime conversion to isoformat
- masked_columns exclusion
"""
from __future__ import annotations
import datetime

View File

@@ -49,7 +49,7 @@ class TestPendingMessage:
"""PendingMessage should be created with correct fields."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -88,7 +88,7 @@ class TestSessionBuffer:
"""SessionBuffer should accept initial messages."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("hello")
chain2 = text_chain("world")
chain1 = text_chain('hello')
chain2 = text_chain('world')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
# Should contain both messages with separator
merged_str = str(merged.message_chain)
assert "hello" in merged_str
assert "world" in merged_str
assert 'hello' in merged_str
assert 'world' in merged_str
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("first")
chain2 = text_chain("second")
chain1 = text_chain('first')
chain2 = text_chain('second')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()

View File

@@ -15,6 +15,7 @@ from tests.factories import FakeApp
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -36,9 +37,11 @@ def mock_circular_import_chain():
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
@@ -70,9 +73,12 @@ def mock_event_ctx():
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
@@ -87,6 +93,7 @@ def get_chat_handler():
global _chat_handler_module
if _chat_handler_module is None:
from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module
@@ -96,12 +103,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class."""
@@ -188,9 +197,11 @@ class TestChatMessageHandlerReal:
class QuickRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='ok')
@@ -222,9 +233,11 @@ class TestChatMessageHandlerReal:
class SingleRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='response')
@@ -262,9 +275,11 @@ class TestChatHandlerStreaming:
class StreamRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True)
@@ -303,14 +318,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class FailingRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('API error')
yield
@@ -346,14 +366,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class ErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('Custom error')
yield
@@ -386,14 +411,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class HideErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise RuntimeError('hidden')
yield
@@ -433,4 +463,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result

View File

@@ -67,7 +67,11 @@ def make_pipeline_config(**overrides):
for key, value in overrides.items():
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
for sub_key, sub_value in value.items():
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
if (
sub_key in base_config[key]
and isinstance(base_config[key][sub_key], dict)
and isinstance(sub_value, dict)
):
base_config[key][sub_key].update(sub_value)
else:
base_config[key][sub_key] = sub_value
@@ -141,7 +145,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -163,7 +167,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
# Empty message chain
query = text_query("")
query = text_query('')
query.message_chain = platform_message.MessageChain([])
query.pipeline_config = pipeline_config
@@ -185,7 +189,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query(" ") # Only whitespace
query = text_query(' ') # Only whitespace
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -234,7 +238,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -266,7 +270,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -294,7 +298,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("http://example.com")
query = text_query('http://example.com')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -322,7 +326,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("normal message")
query = text_query('normal message')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -343,7 +347,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -368,12 +372,10 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Add a response message
query.resp_messages = [
provider_message.Message(role='assistant', content='Hello back!')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -398,11 +400,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='Response')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -422,7 +422,7 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-string content - use model_construct to bypass validation
# The actual content type could be a list of ContentElement objects
@@ -450,11 +450,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='')
]
query.resp_messages = [provider_message.Message(role='assistant', content='')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
@@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect:
await stage.initialize(pipeline_config)
query = text_query("normal message without prefix")
query = text_query('normal message without prefix')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')

View File

@@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -56,6 +57,7 @@ def mock_event_ctx():
@pytest.fixture
def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute(
text: str | None = 'ok',
error: str | None = None,
@@ -71,7 +73,9 @@ def mock_execute_factory():
ret.image_base64 = image_base64
ret.file_url = file_url
yield ret
return mock_execute
return _create_execute
@@ -86,6 +90,7 @@ def get_command_handler():
global _command_handler_module
if _command_handler_module is None:
from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module
@@ -95,12 +100,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal:
"""Tests for real CommandHandler class."""
@@ -127,6 +134,7 @@ class TestCommandHandlerReal:
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
executed_commands = []
async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text)
ret = Mock()
@@ -334,8 +342,7 @@ class TestCommandHandlerReal:
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:',
image_url='https://example.com/image.png'
text='Here is the image:', image_url='https://example.com/image.png'
)
handler = command.CommandHandler(fake_app)
@@ -393,4 +400,4 @@ class TestCommandHandlerHelper:
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result

View File

@@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="very long response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="short response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-Plain component (Image)
query.resp_message_chain = [
platform_message.MessageChain([
platform_message.Plain(text="short"),
platform_message.Image(url="https://example.com/img.png")
])
platform_message.MessageChain(
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
)
]
result = await stage.process(query, 'LongTextProcessStage')
@@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess:
stage = longtext.LongTextProcessStage(app)
stage.strategy_impl = AsyncMock()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
query.resp_message_chain = []
@@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess:
assert result.new_query is query
stage.strategy_impl.process.assert_not_called()
class TestForwardStrategy:
"""Tests for ForwardComponentStrategy."""
@@ -260,7 +256,7 @@ class TestForwardStrategy:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Create a mock adapter with bot_account_id
mock_adapter = Mock()
@@ -268,10 +264,8 @@ class TestForwardStrategy:
query.adapter = mock_adapter
# Long text exceeding threshold
long_text = "This is a very long response that exceeds the threshold"
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=long_text)])
]
long_text = 'This is a very long response that exceeds the threshold'
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -297,13 +291,13 @@ class TestForwardStrategy:
await strat.initialize()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config()
mock_adapter = Mock()
mock_adapter.bot_account_id = '12345'
query.adapter = mock_adapter
components = await strat.process("test message", query)
components = await strat.process('test message', query)
assert len(components) == 1
assert isinstance(components[0], platform_message.Forward)
@@ -326,14 +320,12 @@ class TestLongTextThreshold:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Text below threshold
short_text = "x" * (threshold - 1)
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=short_text)])
]
short_text = 'x' * (threshold - 1)
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
result = await stage.process(query, 'LongTextProcessStage')

View File

@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
# Create query with 3 messages (within limit)
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
# Create query with many messages exceeding limit
# 7 messages = 3 full rounds + 1 current user
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = []
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='hello'),
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='user1'),
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='old1'),
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
trun = trun_cls(app)
break
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_truncate_config(max_round=3)
query.messages = [
provider_message.Message(role='user', content='m1'),

View File

@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello world")
query = text_query('hello world')
result = await stage.process(query, 'PreProcessor')
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("test message")
query = text_query('test message')
result = await stage.process(query, 'PreProcessor')
@@ -194,13 +194,16 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app)
# Image query with base64
query = image_query(text="look at this", url=None)
query = image_query(text='look at this', url=None)
# Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([
platform_message.Plain(text="look at this"),
platform_message.Image(base64="data:image/png;base64,abc123"),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text='look at this'),
platform_message.Image(base64='data:image/png;base64,abc123'),
]
)
query.message_chain = chain
result = await stage.process(query, 'PreProcessor')
@@ -238,7 +241,7 @@ class TestPreProcessorImageSegment:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = image_query(text="describe this")
query = image_query(text='describe this')
result = await stage.process(query, 'PreProcessor')
@@ -276,7 +279,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
# Set pipeline config with primary model
query.pipeline_config = {
@@ -335,7 +338,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = {
'ai': {
@@ -384,7 +387,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello", sender_id=67890)
query = text_query('hello', sender_id=67890)
result = await stage.process(query, 'PreProcessor')
@@ -421,7 +424,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = group_text_query("hello", group_id=99999)
query = group_text_query('hello', group_id=99999)
result = await stage.process(query, 'PreProcessor')

View File

@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
'safety': {
'rate-limit': {
'window-length': 60, # 60 seconds window
'limitation': 10, # 10 requests per window
'limitation': 10, # 10 requests per window
'strategy': 'drop',
}
}
@@ -75,11 +75,9 @@ class TestFixedWindowAlgo:
# Make requests within limit
for i in range(10):
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
)
assert result is True, f"Request {i+1} should be allowed"
assert result is True, f'Request {i + 1} should be allowed'
@pytest.mark.asyncio
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -91,20 +89,12 @@ class TestFixedWindowAlgo:
# Exhaust the limit
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Next request should be denied
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
assert result is False, "Request exceeding limit should be denied"
assert result is False, 'Request exceeding limit should be denied'
@pytest.mark.asyncio
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -116,20 +106,14 @@ class TestFixedWindowAlgo:
# Exhaust limit for session 1
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session1'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
# Session 2 should still have its own limit
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session2'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
)
assert result is True, "Different session should have independent limit"
assert result is True, 'Different session should have independent limit'
@pytest.mark.asyncio
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
@@ -150,19 +134,11 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request allowed
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result1 is True
# Second request denied
result2 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result2 is False
@pytest.mark.asyncio
@@ -174,11 +150,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request creates container
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
expected_key = 'LauncherTypes.PERSON_12345'
@@ -230,7 +202,7 @@ class TestFixedWindowAlgo:
# New request should be allowed (new window)
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
assert result is True, "New window should allow new requests"
assert result is True, 'New window should allow new requests'
@pytest.mark.asyncio
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
@@ -256,29 +228,21 @@ class TestFixedWindowAlgo:
# First request allowed
start_time = time.time()
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
assert result1 is True
# Exhaust limit
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
# Third request should wait and then succeed
result3 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
elapsed = time.time() - start_time
assert result3 is True, "After wait, request should succeed"
assert result3 is True, 'After wait, request should succeed'
# Should have waited approximately until next window
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
# Note: This is a timing-sensitive test, so we use a generous tolerance
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
@pytest.mark.asyncio
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -289,11 +253,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# release_access is empty in current implementation
await algo.release_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Should not raise or change state
assert 'person_12345' not in algo.containers

View File

@@ -55,7 +55,7 @@ def make_session():
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
platform_message.MessageChain([platform_message.Plain(text="response")])
]
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
query.resp_message_chain = []
results = []
@@ -125,7 +123,7 @@ class TestResponseWrapperCommand:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -133,7 +131,7 @@ class TestResponseWrapperCommand:
command_resp = Mock()
command_resp.role = 'command'
command_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
)
query.resp_messages = [command_resp]
@@ -163,7 +161,7 @@ class TestResponseWrapperPlugin:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -171,7 +169,7 @@ class TestResponseWrapperPlugin:
plugin_resp = Mock()
plugin_resp.role = 'plugin'
plugin_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
)
query.resp_messages = [plugin_resp]
@@ -211,17 +209,17 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello back!"
assistant_resp.content = 'Hello back!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
)
query.resp_messages = [assistant_resp]
@@ -247,7 +245,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -292,7 +290,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -303,10 +301,10 @@ class TestResponseWrapperAssistant:
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Processing..."
assistant_resp.content = 'Processing...'
assistant_resp.tool_calls = [mock_tool_call]
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
)
query.resp_messages = [assistant_resp]
@@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello!"
assistant_resp.content = 'Hello!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
)
query.resp_messages = [assistant_resp]
@@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply:
app.sess_mgr.get_session = AsyncMock(return_value=session)
# Mock plugin connector with custom reply
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
@@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Default reply"
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
@@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply:
assert len(results[0].new_query.resp_message_chain) == 1
# Should be the custom chain
chain = results[0].new_query.resp_message_chain[0]
assert "Custom reply" in str(chain)
assert 'Custom reply' in str(chain)
class TestResponseWrapperVariables:
@@ -452,7 +450,7 @@ class TestResponseWrapperVariables:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
@@ -460,10 +458,10 @@ class TestResponseWrapperVariables:
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello"
assistant_resp.content = 'Hello'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
)
query.resp_messages = [assistant_resp]

View File

@@ -6,6 +6,7 @@ Tests cover:
- RAG methods (ingest, retrieve, schema)
- Disabled plugin early returns
"""
from __future__ import annotations
import pytest
@@ -86,16 +87,12 @@ class TestListPlugins:
return_value=[
{
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Command'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}],
'debug': False,
},
{
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
'components': [
{'manifest': {'manifest': {'kind': 'Tool'}}}
],
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}],
'debug': False,
},
]
@@ -127,9 +124,7 @@ class TestListPlugins:
},
]
)
connector.ap.persistence_mgr.execute_async = AsyncMock(
return_value=Mock(__iter__=lambda self: iter([]))
)
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([])))
result = await connector.list_plugins()
@@ -230,7 +225,8 @@ class TestCallParser:
)
connector.handler.parse_document.assert_called_once_with(
'author', 'parser',
'author',
'parser',
{'mime_type': 'text/plain', 'filename': 'test.txt'},
b'file content',
)
@@ -251,9 +247,7 @@ class TestRAGMethods:
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
connector.handler.rag_ingest_document.assert_called_once_with(
'author', 'engine', {'file': 'test.pdf'}
)
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'})
assert result['status'] == 'success'
@pytest.mark.asyncio
@@ -264,14 +258,16 @@ class TestRAGMethods:
connector.handler = AsyncMock()
connector.handler.retrieve_knowledge = AsyncMock(
return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
return_value={
'results': [
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
]
}
)
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
connector.handler.retrieve_knowledge.assert_called_once_with(
'author', 'engine', '', {'query': 'test'}
)
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'})
assert result == {
'results': [
{
@@ -290,9 +286,7 @@ class TestRAGMethods:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}})
result = await connector.get_rag_creation_schema('author/engine')
@@ -326,9 +320,7 @@ class TestRAGMethods:
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
connector.handler.rag_on_kb_create.assert_called_once_with(
'author', 'engine', 'kb-uuid', {'model': 'test'}
)
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'})
@pytest.mark.asyncio
async def test_rag_on_kb_delete(self):
@@ -354,9 +346,7 @@ class TestRAGMethods:
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
connector.handler.rag_delete_document.assert_called_once_with(
'author', 'engine', 'doc-uuid', 'kb-uuid'
)
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid')
assert result is True
@@ -446,9 +436,7 @@ class TestGetPluginInfo:
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.get_plugin_info = AsyncMock(
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
)
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}})
result = await connector.get_plugin_info('author', 'plugin')
@@ -470,9 +458,7 @@ class TestSetPluginConfig:
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
connector.handler.set_plugin_config.assert_called_once_with(
'author', 'plugin', {'setting': 'value'}
)
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'})
class TestPingPluginRuntime:

View File

@@ -3,6 +3,7 @@
Tests cover:
- _parse_plugin_id() parsing and validation
"""
from __future__ import annotations
import pytest

View File

@@ -6,6 +6,7 @@ Tests cover:
- Handling missing requirements.txt
- Handling empty/malformed requirements.txt
"""
from __future__ import annotations
import zipfile
@@ -82,13 +83,13 @@ class TestExtractDepsMetadata:
"""Test that comments and empty lines are filtered."""
connector_instance = create_mock_connector()
requirements = '''# This is a comment
requirements = """# This is a comment
requests>=2.0
# Another comment
flask==1.0
numpy'''
numpy"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()
@@ -147,9 +148,9 @@ numpy'''
"""Test handling requirements.txt with only comments."""
connector_instance = create_mock_connector()
requirements = '''# Comment 1
requirements = """# Comment 1
# Comment 2
# Comment 3'''
# Comment 3"""
zip_bytes = create_zip_with_requirements(requirements)
task_context = Mock()

View File

@@ -40,11 +40,13 @@ class TestHandlerQueryVariables:
"""Test set_query_var returns error when query not found."""
runtime_handler = make_handler(mock_app)
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'nonexistent-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -58,11 +60,13 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
})
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'test_var',
'value': 'test_value',
}
)
assert response.code == 0
assert mock_query.variables['test_var'] == 'test_value'
@@ -76,10 +80,12 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
'query_id': 'test-query',
'key': 'existing_var',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value](
{
'query_id': 'test-query',
'key': 'existing_var',
}
)
assert response.code == 0
assert response.data == {'value': 'existing_value'}
@@ -93,9 +99,11 @@ class TestHandlerQueryVariables:
mock_app.query_pool.cached_queries['test-query'] = mock_query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
'query_id': 'test-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value](
{
'query_id': 'test-query',
}
)
assert response.code == 0
assert response.data == {'vars': mock_query.variables}
@@ -108,7 +116,7 @@ class TestHandlerRagErrorResponse:
"""Test basic error response creation."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = Exception("test error")
error = Exception('test error')
response = _make_rag_error_response(error, 'TestError')
# ActionResponse is a pydantic model, check message field
@@ -120,13 +128,8 @@ class TestHandlerRagErrorResponse:
"""Test error response with extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = ValueError("invalid input")
response = _make_rag_error_response(
error,
'ValidationError',
field='name',
value='test'
)
error = ValueError('invalid input')
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test')
assert 'ValidationError' in response.message
assert 'field=name' in response.message
@@ -137,7 +140,7 @@ class TestHandlerRagErrorResponse:
"""Test error response includes exception type."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = RuntimeError("connection failed")
error = RuntimeError('connection failed')
response = _make_rag_error_response(error, 'ConnectionError')
assert 'RuntimeError' in response.message
@@ -148,7 +151,7 @@ class TestHandlerRagErrorResponse:
"""Test error response with no extra context."""
from langbot.pkg.plugin.handler import _make_rag_error_response
error = KeyError("missing_key")
error = KeyError('missing_key')
response = _make_rag_error_response(error, 'LookupError')
# No context parts means no brackets

View File

@@ -47,12 +47,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'local',
'install_info': {'path': '/test'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -82,12 +84,14 @@ class TestInitializePluginSettings:
Mock(),
]
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
})
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
'install_source': 'github',
'install_info': {'repo': 'author/name'},
}
)
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 3
@@ -161,9 +165,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'new')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new'))
assert response.code == 0
assert app.persistence_mgr.execute_async.await_count == 2
@@ -203,9 +205,7 @@ class TestSetBinaryStorage:
runtime_handler = make_handler(app)
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
self.payload(b'x')
)
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x'))
assert response.code != 0
assert '1 > 0 bytes' in response.message
@@ -228,10 +228,12 @@ class TestGetPluginSettings:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -255,10 +257,12 @@ class TestGetPluginSettings:
)
app.persistence_mgr.execute_async.return_value = make_result(setting)
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
{
'plugin_author': 'test-author',
'plugin_name': 'test-plugin',
}
)
assert response.code == 0
assert response.data == {
@@ -286,11 +290,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code == 0
assert response.data == {
@@ -303,11 +309,13 @@ class TestGetBinaryStorage:
runtime_handler = make_handler(app)
app.persistence_mgr.execute_async.return_value = make_result()
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
})
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
{
'key': 'test-key',
'owner_type': 'plugin',
'owner': 'test-owner',
}
)
assert response.code != 0
assert 'Storage with key test-key not found' in response.message
@@ -329,9 +337,11 @@ class TestHandlerQueryLookup:
"""Query-bound actions return error when query_id is not cached."""
runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'nonexistent-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'nonexistent-query',
}
)
assert response.code != 0
assert 'nonexistent-query' in response.message
@@ -343,9 +353,11 @@ class TestHandlerQueryLookup:
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
app.query_pool.cached_queries['existing-query'] = query
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
'query_id': 'existing-query',
})
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
{
'query_id': 'existing-query',
}
)
assert response.code == 0
assert response.data == {'bot_uuid': 'test-bot-uuid'}

View File

@@ -4,6 +4,7 @@ Tests cover:
- _make_rag_error_response() helper function
- RuntimeConnectionHandler cleanup_plugin_data method
"""
from __future__ import annotations
import pytest
@@ -23,7 +24,7 @@ class TestMakeRagErrorResponse:
"""Test basic error response creation."""
handler = get_handler_module()
error = ValueError("test error message")
error = ValueError('test error message')
result = handler._make_rag_error_response(error, 'TestError')
# ActionResponse.error() returns code=1 (error status)
@@ -36,7 +37,7 @@ class TestMakeRagErrorResponse:
"""Test that error type is included in message."""
handler = get_handler_module()
error = RuntimeError("something went wrong")
error = RuntimeError('something went wrong')
result = handler._make_rag_error_response(error, 'VectorStoreError')
assert '[VectorStoreError/RuntimeError]' in result.message
@@ -45,7 +46,7 @@ class TestMakeRagErrorResponse:
"""Test that extra context fields are included."""
handler = get_handler_module()
error = Exception("embedding failed")
error = Exception('embedding failed')
result = handler._make_rag_error_response(
error,
'EmbeddingError',
@@ -71,7 +72,7 @@ class TestMakeRagErrorResponse:
"""Test multiple context fields are comma separated."""
handler = get_handler_module()
error = IOError("file not found")
error = IOError('file not found')
result = handler._make_rag_error_response(
error,
'FileServiceError',
@@ -119,9 +120,7 @@ class TestCleanupPluginData:
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
handler_instance.ap = mock_app
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
handler_instance, 'author', 'plugin-name'
)
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name')
# Should have at least 2 calls: one for settings, one for binary storage
assert mock_app.persistence_mgr.execute_async.call_count >= 2
assert mock_app.persistence_mgr.execute_async.call_count >= 2

View File

@@ -88,7 +88,10 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
return provider_message.Message(
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
)
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
"""Return fake rerank results."""
@@ -135,8 +138,10 @@ def mock_app_for_modelmgr():
# Fake persistence manager - returns empty results by default
app.persistence_mgr = SimpleNamespace()
async def default_execute(query):
return _make_mock_result([])
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
# Fake discover engine
@@ -165,9 +170,7 @@ def fake_requester_registry(mock_app_for_modelmgr):
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
app.discover.get_components_by_kind = Mock(
return_value=[fake_component, another_component]
)
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component])
model_mgr = ModelManager(app)
return model_mgr

View File

@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
with pytest.raises(DifyAPIError, match='不支持'):
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
'output': {'misc': {}},
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app
assert runner.ap == mock_app

View File

@@ -1062,9 +1062,7 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
mock_get_model_info.side_effect = (
lambda model: {'max_input_tokens': 131072}
if model == 'moonshot/moonshot-v1-128k'
else {}
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {}
)
assert requester._safe_context_length('moonshot-v1-128k') == 131072

View File

@@ -180,7 +180,7 @@ class TestMCPServerBoxConfig:
assert cfg.host_path is None
assert cfg.host_path_mode == 'ro'
assert cfg.env == {}
assert cfg.startup_timeout_sec == 300
assert cfg.startup_timeout_sec == 120
assert cfg.cpus is None
assert cfg.memory_mb is None
assert cfg.pids_limit is None
@@ -494,84 +494,6 @@ class TestBuildBoxProcessPayload:
assert payload['args'] == ['/opt/other/server.py', '--flag']
# ── Python Workspace Preparation ────────────────────────────────────
class TestPythonWorkspacePreparation:
def test_requirements_workspace_uses_venv_bootstrap(self, mcp_module, tmp_path):
host_path = tmp_path / 'mcp-source'
host_path.mkdir()
(host_path / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8')
command = mcp_module.BoxStdioSessionRuntime.detect_install_command(
str(host_path),
'/workspace/.mcp/u1/workspace',
)
assert command is not None
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
assert 'python -m pip install -r "/workspace/.mcp/u1/workspace/requirements.txt"' in command
assert 'pip install --no-cache-dir -r' not in command
def test_staging_refresh_removes_stale_source_files_but_preserves_runtime_dirs(self, mcp_module, tmp_path):
source = tmp_path / 'source'
source.mkdir()
(source / 'server.py').write_text('print("new")\n', encoding='utf-8')
(source / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8')
(source / '.env').write_text('TOKEN=new\n', encoding='utf-8')
process_root = tmp_path / 'shared' / '.mcp' / 'u1'
workspace = process_root / 'workspace'
(workspace / '.venv' / 'bin').mkdir(parents=True)
(workspace / '.venv' / 'bin' / 'python').write_text('', encoding='utf-8')
(workspace / '.langbot').mkdir()
(workspace / '.langbot' / 'python-env.lock').mkdir()
(workspace / '.env').write_text('TOKEN=old\n', encoding='utf-8')
(workspace / 'server.py').write_text('print("old")\n', encoding='utf-8')
(workspace / 'removed.py').write_text('stale\n', encoding='utf-8')
(workspace / 'removed_dir').mkdir()
(workspace / 'removed_dir' / 'old.txt').write_text('stale\n', encoding='utf-8')
mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace))
assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n'
assert (workspace / 'requirements.txt').read_text(encoding='utf-8') == 'mcp==1.26.0\n'
assert (workspace / '.env').read_text(encoding='utf-8') == 'TOKEN=new\n'
assert not (workspace / 'removed.py').exists()
assert not (workspace / 'removed_dir').exists()
assert (workspace / '.venv' / 'bin' / 'python').exists()
assert (workspace / '.langbot' / 'python-env.lock').is_dir()
def test_staging_refresh_ignores_unlink_race(self, mcp_module, tmp_path, monkeypatch):
mcp_stdio_module = sys.modules['langbot.pkg.provider.tools.loaders.mcp_stdio']
source = tmp_path / 'source'
source.mkdir()
(source / 'server.py').write_text('print("new")\n', encoding='utf-8')
process_root = tmp_path / 'shared' / '.mcp' / 'u1'
workspace = process_root / 'workspace'
workspace.mkdir(parents=True)
stale_file = workspace / 'removed.py'
stale_file.write_text('stale\n', encoding='utf-8')
real_unlink = os.unlink
def unlink_with_race(path):
if os.fspath(path) == str(stale_file):
real_unlink(path)
raise FileNotFoundError(path)
real_unlink(path)
monkeypatch.setattr(mcp_stdio_module.os, 'unlink', unlink_with_race)
mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace))
assert not stale_file.exists()
assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n'
# ── get_runtime_info_dict ───────────────────────────────────────────

View File

@@ -635,7 +635,9 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
model_mgr = fake_requester_registry
@@ -648,7 +650,9 @@ async def test_model_manager_load_llm_model_with_provider(fake_requester_registr
@pytest.mark.asyncio
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_llm_model_with_provider_from_row(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
model_mgr = fake_requester_registry
@@ -661,7 +665,9 @@ async def test_model_manager_load_llm_model_with_provider_from_row(fake_requeste
@pytest.mark.asyncio
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
async def test_model_manager_load_embedding_model_with_provider(
fake_requester_registry, fake_persistence_data, runtime_provider
):
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
model_mgr = fake_requester_registry

View File

@@ -43,6 +43,7 @@ class TestableRequester(requester.ProviderAPIRequester):
remove_think=False,
):
import langbot_plugin.api.entities.builtin.provider.message as provider_message
return provider_message.Message(
role='assistant',
content=[provider_message.ContentElement(type='text', text='Testable response')],
@@ -289,7 +290,9 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
result = await provider.invoke_llm(query, runtime_llm_model, messages)
@@ -330,7 +333,9 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
chunks = []
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
@@ -576,7 +581,9 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
current_stage_name=None,
)
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
messages = [
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
]
with pytest.raises(RequesterError):
await provider.invoke_llm(query, model, messages)

View File

@@ -5,6 +5,7 @@ Tests cover:
- Conversation creation with prompts
- Session concurrency semaphore
"""
from __future__ import annotations
import pytest
@@ -60,11 +61,7 @@ class TestSessionManagerGetSession:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -173,11 +170,7 @@ class TestSessionManagerGetConversation:
"""Create mock app with instance config."""
mock_app = Mock()
mock_app.instance_config = Mock()
mock_app.instance_config.data = {
'concurrency': {
'session': 5
}
}
mock_app.instance_config.data = {'concurrency': {'session': 5}}
return mock_app
@pytest.fixture
@@ -201,17 +194,13 @@ class TestSessionManagerGetConversation:
return query
@pytest.mark.asyncio
async def test_creates_conversation_with_prompt(
self, mock_app_with_config, sample_query, sample_session
):
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session):
"""Test that get_conversation creates conversation with prompt."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
@@ -234,21 +223,15 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
pipeline_uuid = 'pipeline-123'
bot_uuid = 'bot-123'
# First call creates conversation
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
# Second call with same pipeline should return same conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
assert conv1 is conv2
assert len(sample_session.conversations) == 1
@@ -262,36 +245,26 @@ class TestSessionManagerGetConversation:
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
# First call with pipeline1
conv1 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
)
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1')
# Second call with different pipeline should create new conversation
conv2 = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
)
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2')
assert conv1 is not conv2
assert len(sample_session.conversations) == 2
assert sample_session.using_conversation is conv2
@pytest.mark.asyncio
async def test_conversation_has_empty_messages(
self, mock_app_with_config, sample_query, sample_session
):
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session):
"""Test that created conversation has empty messages list."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'You are a helpful assistant.'}
]
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
@@ -300,22 +273,17 @@ class TestSessionManagerGetConversation:
assert conversation.messages == []
@pytest.mark.asyncio
async def test_prompt_messages_from_config(
self, mock_app_with_config, sample_query, sample_session
):
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session):
"""Test that prompt messages are created from prompt_config."""
sessionmgr = get_session_module()
manager = sessionmgr.SessionManager(mock_app_with_config)
prompt_config = [
{'role': 'system', 'content': 'System message'},
{'role': 'user', 'content': 'User message'}
]
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}]
conversation = await manager.get_conversation(
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
)
assert conversation.prompt.name == 'default'
assert len(conversation.prompt.messages) == 2
assert len(conversation.prompt.messages) == 2

View File

@@ -193,29 +193,6 @@ class TestSkillPathHelpers:
assert list(result.keys()) == ['visible']
def test_restore_activated_skills_uses_caller_provided_names_and_visibility(self):
from langbot.pkg.provider.tools.loaders.skill import (
ACTIVATED_SKILLS_KEY,
PIPELINE_BOUND_SKILLS_KEY,
get_activated_skill_names,
restore_activated_skills,
)
ap = _make_ap()
ap.skill_mgr = SimpleNamespace(
skills={
'visible': _make_skill_data(name='visible'),
'hidden': _make_skill_data(name='hidden'),
}
)
query = SimpleNamespace(variables={PIPELINE_BOUND_SKILLS_KEY: ['visible']})
restored = restore_activated_skills(ap, query, ['visible', 'hidden', 'visible', ''])
assert restored == ['visible']
assert list(query.variables[ACTIVATED_SKILLS_KEY].keys()) == ['visible']
assert get_activated_skill_names(query) == ['visible']
def test_resolve_virtual_skill_path_allows_visible_skill_reads(self):
from langbot.pkg.provider.tools.loaders.skill import (
PIPELINE_BOUND_SKILLS_KEY,
@@ -268,8 +245,7 @@ class TestSkillPathHelpers:
command = wrap_skill_command_with_python_env('python scripts/run.py')
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
assert 'python -m venv "$_LB_VENV_DIR"' in command
assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command
assert command.rstrip().endswith('python scripts/run.py')
@@ -305,7 +281,6 @@ class TestSkillToolLoader:
assert result['activated'] is True
assert result['skill_name'] == 'demo'
assert result['mount_path'] == '/workspace/.skills/demo'
assert result['activated_skill_names'] == ['demo']
assert 'Step 1' in result['content']
assert set(query.variables[ACTIVATED_SKILLS_KEY].keys()) == {'demo'}
@@ -481,9 +456,7 @@ class TestNativeToolLoaderSkillPaths:
SimpleNamespace(query_id='q1', variables={PIPELINE_BOUND_SKILLS_KEY: ['demo']}),
)
assert result['ok'] is True
assert result['content'] == 'demo instructions'
assert result['truncated'] is False
assert result == {'ok': True, 'content': 'demo instructions'}
@pytest.mark.asyncio
async def test_exec_in_activated_skill_mount_rewrites_command_and_refreshes(self):
@@ -512,7 +485,7 @@ class TestNativeToolLoaderSkillPaths:
query,
)
assert result['ok'] is True
assert result == {'ok': True}
tool_parameters = ap.box_service.execute_tool.await_args.args[0]
assert tool_parameters['command'] == 'python /workspace/.skills/demo/scripts/run.py'
assert tool_parameters['workdir'] == '/workspace/.skills/demo'

View File

@@ -136,6 +136,7 @@ class TestToolManagerSchemaGeneration:
assert 'description' in func
assert 'parameters' in func
class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method."""

View File

@@ -248,135 +248,3 @@ async def test_path_escape_blocked():
with pytest.raises(ValueError, match='escapes'):
await loader.invoke_tool('read', {'path': '/workspace/../../etc/passwd'}, _make_query())
@pytest.mark.asyncio
async def test_box_availability_helper_handles_unavailable_and_errors():
from langbot.pkg.provider.tools.loaders.availability import is_box_backend_available
assert await is_box_backend_available(SimpleNamespace()) is False
assert await is_box_backend_available(SimpleNamespace(box_service=SimpleNamespace(available=False))) is False
unavailable_backend = SimpleNamespace(
available=True,
get_status=AsyncMock(return_value={'backend': {'available': False}}),
)
assert await is_box_backend_available(SimpleNamespace(box_service=unavailable_backend)) is False
failing_backend = SimpleNamespace(
available=True,
get_status=AsyncMock(side_effect=RuntimeError('box unavailable')),
)
assert await is_box_backend_available(SimpleNamespace(box_service=failing_backend)) is False
@pytest.mark.asyncio
async def test_read_file_supports_offset_limit_and_truncation_metadata():
with tempfile.TemporaryDirectory() as tmpdir:
loader, _ = _make_loader_with_workspace(tmpdir)
with open(os.path.join(tmpdir, 'lines.txt'), 'w', encoding='utf-8') as f:
f.write('one\ntwo\nthree\nfour\n')
result = await loader.invoke_tool(
'read',
{'path': '/workspace/lines.txt', 'offset': 2, 'limit': 2},
_make_query(),
)
assert result == {
'ok': True,
'content': 'two\nthree',
'truncated': True,
'truncated_by': 'lines',
'start_line': 2,
'end_line': 3,
'next_offset': 4,
'max_lines': 2,
'max_bytes': 50 * 1024,
}
@pytest.mark.asyncio
async def test_read_file_handles_line_larger_than_byte_limit():
with tempfile.TemporaryDirectory() as tmpdir:
loader, _ = _make_loader_with_workspace(tmpdir)
with open(os.path.join(tmpdir, 'long-line.txt'), 'w', encoding='utf-8') as f:
f.write('abcdef\n')
result = await loader.invoke_tool(
'read',
{'path': '/workspace/long-line.txt', 'max_bytes': 3},
_make_query(),
)
assert result['ok'] is True
assert result['truncated'] is True
assert result['truncated_by'] == 'bytes'
assert result['next_offset'] == 1
assert 'exceeds the 3B read limit' in result['content']
@pytest.mark.asyncio
async def test_exec_result_is_capped_and_exposes_preview_metadata():
with tempfile.TemporaryDirectory() as tmpdir:
box_service = SimpleNamespace(
available=True,
default_workspace=tmpdir,
execute_tool=AsyncMock(
return_value={
'ok': True,
'stdout': 'a' * 60000,
'stderr': 'b' * 60000,
'exit_code': 0,
}
),
)
loader = NativeToolLoader(SimpleNamespace(box_service=box_service, logger=Mock()))
result = await loader.invoke_tool('exec', {'command': 'python -V'}, _make_query())
assert result['ok'] is True
assert len(result['stdout'].encode('utf-8')) == 50 * 1024
assert len(result['stderr'].encode('utf-8')) == 50 * 1024
assert len(result['preview'].encode('utf-8')) == 50 * 1024
assert result['stdout_truncated'] is True
assert result['stderr_truncated'] is True
assert result['truncated'] is True
assert result['truncated_by'] == 'bytes'
@pytest.mark.asyncio
async def test_glob_caps_match_count_and_returns_preview():
with tempfile.TemporaryDirectory() as tmpdir:
loader, _ = _make_loader_with_workspace(tmpdir)
for index in range(105):
with open(os.path.join(tmpdir, f'file-{index:03d}.txt'), 'w', encoding='utf-8') as f:
f.write(str(index))
result = await loader.invoke_tool('glob', {'path': '/workspace', 'pattern': '*.txt'}, _make_query())
assert result['ok'] is True
assert result['total'] == 105
assert len(result['matches']) == 100
assert result['preview'] == '\n'.join(result['matches'])
assert result['truncated'] is True
assert result['truncated_by'] == 'matches'
@pytest.mark.asyncio
async def test_grep_reports_invalid_regex_and_truncates_long_matching_lines():
with tempfile.TemporaryDirectory() as tmpdir:
loader, _ = _make_loader_with_workspace(tmpdir)
with open(os.path.join(tmpdir, 'data.txt'), 'w', encoding='utf-8') as f:
f.write('needle ' + ('x' * 600) + '\n')
invalid = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': '['}, _make_query())
result = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': 'needle'}, _make_query())
assert invalid['ok'] is False
assert 'Invalid regex' in invalid['error']
assert result['ok'] is True
assert result['truncated'] is True
assert result['truncated_by'] == 'line'
assert result['matches'][0]['file'] == '/workspace/data.txt'
assert result['matches'][0]['content'].endswith('... [truncated]')

View File

@@ -3,6 +3,7 @@
Tests cover:
- _to_i18n_name() static method
"""
from __future__ import annotations
from importlib import import_module
@@ -60,4 +61,4 @@ class TestToI18nName:
kbmgr = get_kbmgr_module()
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
result = kbmgr.RAGManager._to_i18n_name(input_dict)
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}

Some files were not shown because too many files have changed in this diff Show More