mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-16 18:56:02 +00:00
Compare commits
12 Commits
codex/agen
...
feat/host-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8789c42eeb | ||
|
|
b3c6de2072 | ||
|
|
4e45886647 | ||
|
|
f592656680 | ||
|
|
e9db858dcc | ||
|
|
2d6faf9d5e | ||
|
|
d4699547e9 | ||
|
|
716d7aca94 | ||
|
|
b3c00fe6da | ||
|
|
f4a6edf7ec | ||
|
|
f390980d0a | ||
|
|
1ae5aacc00 |
46
.github/workflows/frontend-tests.yml
vendored
Normal file
46
.github/workflows/frontend-tests.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: Frontend Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- 'web/**'
|
||||
- '.github/workflows/frontend-tests.yml'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- develop
|
||||
paths:
|
||||
- 'web/**'
|
||||
- '.github/workflows/frontend-tests.yml'
|
||||
|
||||
jobs:
|
||||
playwright-smoke:
|
||||
name: Playwright Smoke
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '25'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8.9.2
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Install Playwright browsers
|
||||
working-directory: web
|
||||
run: pnpm exec playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright smoke tests
|
||||
working-directory: web
|
||||
run: pnpm test:e2e
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run ruff check
|
||||
run: uv run ruff check src
|
||||
run: uv run ruff check src/langbot/ tests/ --output-format=concise
|
||||
|
||||
- name: Run ruff format
|
||||
run: uv run ruff format src --check
|
||||
|
||||
63
.github/workflows/run-tests.yml
vendored
63
.github/workflows/run-tests.yml
vendored
@@ -84,6 +84,67 @@ jobs:
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
e2e:
|
||||
name: E2E Startup Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run E2E startup tests
|
||||
run: uv run pytest tests/e2e -q --tb=short
|
||||
|
||||
- name: E2E Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## E2E Startup Test Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
box-integration:
|
||||
name: Box Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Check Docker runtime
|
||||
run: docker info
|
||||
|
||||
- name: Run Box integration tests
|
||||
run: uv run pytest tests/integration_tests -q --tb=short
|
||||
|
||||
- name: Box Integration Test Summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "## Box Integration Test Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
coverage:
|
||||
name: Coverage Gate
|
||||
runs-on: ubuntu-latest
|
||||
@@ -129,4 +190,4 @@ jobs:
|
||||
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -313,18 +313,30 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Get traces
|
||||
traces, traces_total = await self.ap.monitoring_service.get_traces(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'overview': overview,
|
||||
'messages': messages,
|
||||
'llmCalls': llm_calls,
|
||||
'embeddingCalls': embedding_calls,
|
||||
'traces': traces,
|
||||
'sessions': sessions,
|
||||
'errors': errors,
|
||||
'totalCount': {
|
||||
'messages': messages_total,
|
||||
'llmCalls': llm_calls_total,
|
||||
'embeddingCalls': embedding_calls_total,
|
||||
'traces': traces_total,
|
||||
'sessions': sessions_total,
|
||||
'errors': errors_total,
|
||||
},
|
||||
@@ -350,6 +362,49 @@ class MonitoringRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=details)
|
||||
|
||||
@self.route('/traces', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_traces() -> str:
|
||||
"""Get end-to-end trace records."""
|
||||
bot_ids = quart.request.args.getlist('botId')
|
||||
pipeline_ids = quart.request.args.getlist('pipelineId')
|
||||
session_ids = quart.request.args.getlist('sessionId')
|
||||
statuses = quart.request.args.getlist('status')
|
||||
start_time_str = quart.request.args.get('startTime')
|
||||
end_time_str = quart.request.args.get('endTime')
|
||||
limit = int(quart.request.args.get('limit', 100))
|
||||
offset = int(quart.request.args.get('offset', 0))
|
||||
|
||||
start_time = parse_iso_datetime(start_time_str)
|
||||
end_time = parse_iso_datetime(end_time_str)
|
||||
|
||||
traces, total = await self.ap.monitoring_service.get_traces(
|
||||
bot_ids=bot_ids if bot_ids else None,
|
||||
pipeline_ids=pipeline_ids if pipeline_ids else None,
|
||||
session_ids=session_ids if session_ids else None,
|
||||
statuses=statuses if statuses else None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'traces': traces,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset,
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/traces/<trace_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def get_trace_details(trace_id: str) -> str:
|
||||
"""Get one trace with all spans."""
|
||||
details = await self.ap.monitoring_service.get_trace_details(trace_id)
|
||||
if not details.get('found'):
|
||||
return self.http_status(404, -1, f'Trace {trace_id} not found')
|
||||
return self.success(data=details)
|
||||
|
||||
@self.route('/export', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def export_data() -> tuple[str, int]:
|
||||
"""Export monitoring data as CSV"""
|
||||
|
||||
@@ -350,8 +350,24 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
if not endpoint.startswith('/') or '..' in endpoint:
|
||||
return self.http_status(400, -1, 'invalid endpoint')
|
||||
|
||||
caller = {
|
||||
'plugin_author': author,
|
||||
'plugin_name': plugin_name,
|
||||
'page_id': page_id,
|
||||
'origin': _get_request_origin(),
|
||||
}
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in {
|
||||
'user-agent': quart.request.headers.get('User-Agent'),
|
||||
'x-request-id': quart.request.headers.get('X-Request-ID'),
|
||||
'x-forwarded-for': quart.request.headers.get('X-Forwarded-For'),
|
||||
}.items()
|
||||
if value
|
||||
}
|
||||
|
||||
result = await self.ap.plugin_connector.handle_page_api(
|
||||
author, plugin_name, page_id, endpoint, method.upper(), body
|
||||
author, plugin_name, page_id, endpoint, method.upper(), body, caller, headers
|
||||
)
|
||||
if result.get('error'):
|
||||
return self.http_status(400, -1, result['error'])
|
||||
|
||||
@@ -3,11 +3,53 @@ from __future__ import annotations
|
||||
import uuid
|
||||
import datetime
|
||||
import sqlalchemy
|
||||
import json
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import monitoring as persistence_monitoring
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
def _json_dumps(value: dict | list | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False, default=str)
|
||||
except Exception:
|
||||
return json.dumps({'serialization_error': str(value)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _json_loads(value: str | None) -> dict | list | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_trace_id() -> str:
|
||||
return f'trace-{uuid.uuid4().hex[:16]}'
|
||||
|
||||
|
||||
def new_span_id() -> str:
|
||||
return f'span-{uuid.uuid4().hex[:16]}'
|
||||
|
||||
|
||||
def normalize_trace_status(status: str | None) -> str:
|
||||
"""Normalize operation status to the monitoring UI vocabulary."""
|
||||
if status in ('completed', 'ok'):
|
||||
return 'success'
|
||||
if status in ('failed', 'failure', 'exception'):
|
||||
return 'error'
|
||||
if status in ('running', 'success', 'error'):
|
||||
return status
|
||||
return 'success'
|
||||
|
||||
|
||||
class MonitoringService:
|
||||
"""Monitoring service"""
|
||||
|
||||
@@ -74,6 +116,18 @@ class MonitoringService:
|
||||
persistence_monitoring.MonitoringFeedback.timestamp,
|
||||
persistence_monitoring.MonitoringFeedback.id,
|
||||
),
|
||||
(
|
||||
'monitoring_traces',
|
||||
persistence_monitoring.MonitoringTrace,
|
||||
persistence_monitoring.MonitoringTrace.started_at,
|
||||
persistence_monitoring.MonitoringTrace.trace_id,
|
||||
),
|
||||
(
|
||||
'monitoring_spans',
|
||||
persistence_monitoring.MonitoringSpan,
|
||||
persistence_monitoring.MonitoringSpan.started_at,
|
||||
persistence_monitoring.MonitoringSpan.span_id,
|
||||
),
|
||||
]
|
||||
|
||||
deleted_counts: dict[str, int] = {}
|
||||
@@ -133,6 +187,116 @@ class MonitoringService:
|
||||
|
||||
# ========== Recording Methods ==========
|
||||
|
||||
async def start_trace(
|
||||
self,
|
||||
trace_id: str | None = None,
|
||||
name: str = 'LangBot query',
|
||||
bot_id: str | None = None,
|
||||
bot_name: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
pipeline_name: str | None = None,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
query_id: str | int | None = None,
|
||||
attributes: dict | None = None,
|
||||
) -> str:
|
||||
"""Create or update a trace header row."""
|
||||
trace_id = trace_id or new_trace_id()
|
||||
trace_data = {
|
||||
'trace_id': trace_id,
|
||||
'started_at': _utc_now(),
|
||||
'ended_at': None,
|
||||
'duration': None,
|
||||
'status': 'running',
|
||||
'name': name,
|
||||
'bot_id': bot_id,
|
||||
'bot_name': bot_name,
|
||||
'pipeline_id': pipeline_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'session_id': session_id,
|
||||
'message_id': message_id,
|
||||
'query_id': str(query_id) if query_id is not None else None,
|
||||
'attributes': _json_dumps(attributes),
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringTrace).values(trace_data)
|
||||
)
|
||||
return trace_id
|
||||
|
||||
async def finish_trace(
|
||||
self,
|
||||
trace_id: str,
|
||||
status: str = 'success',
|
||||
duration: int | None = None,
|
||||
message_id: str | None = None,
|
||||
attributes: dict | None = None,
|
||||
) -> None:
|
||||
"""Mark a trace complete."""
|
||||
update_values: dict = {
|
||||
'ended_at': _utc_now(),
|
||||
'status': normalize_trace_status(status),
|
||||
}
|
||||
if duration is not None:
|
||||
update_values['duration'] = duration
|
||||
if message_id is not None:
|
||||
update_values['message_id'] = message_id
|
||||
if attributes is not None:
|
||||
update_values['attributes'] = _json_dumps(attributes)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_monitoring.MonitoringTrace)
|
||||
.where(persistence_monitoring.MonitoringTrace.trace_id == trace_id)
|
||||
.values(update_values)
|
||||
)
|
||||
|
||||
async def record_span(
|
||||
self,
|
||||
trace_id: str,
|
||||
name: str,
|
||||
kind: str,
|
||||
status: str = 'success',
|
||||
span_id: str | None = None,
|
||||
parent_span_id: str | None = None,
|
||||
started_at: datetime.datetime | None = None,
|
||||
ended_at: datetime.datetime | None = None,
|
||||
duration: int | None = None,
|
||||
message_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
attributes: dict | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> str:
|
||||
"""Record a single completed span."""
|
||||
started_at = started_at or _utc_now()
|
||||
if duration is None and ended_at is not None:
|
||||
duration = int((ended_at - started_at).total_seconds() * 1000)
|
||||
elif duration is not None:
|
||||
duration = int(round(float(duration)))
|
||||
span_data = {
|
||||
'span_id': span_id or new_span_id(),
|
||||
'trace_id': trace_id,
|
||||
'parent_span_id': parent_span_id,
|
||||
'name': name,
|
||||
'kind': kind,
|
||||
'status': normalize_trace_status(status),
|
||||
'started_at': started_at,
|
||||
'ended_at': ended_at or _utc_now(),
|
||||
'duration': duration,
|
||||
'message_id': message_id,
|
||||
'session_id': session_id,
|
||||
'bot_id': bot_id,
|
||||
'pipeline_id': pipeline_id,
|
||||
'attributes': _json_dumps(attributes),
|
||||
'error_message': error_message,
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_monitoring.MonitoringSpan).values(span_data)
|
||||
)
|
||||
return span_data['span_id']
|
||||
|
||||
async def record_message(
|
||||
self,
|
||||
bot_id: str,
|
||||
@@ -1076,6 +1240,19 @@ class MonitoringService:
|
||||
for row in error_rows
|
||||
]
|
||||
|
||||
trace_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringTrace)
|
||||
.where(persistence_monitoring.MonitoringTrace.message_id == message_id)
|
||||
.order_by(persistence_monitoring.MonitoringTrace.started_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
trace_result = await self.ap.persistence_mgr.execute_async(trace_query)
|
||||
trace_row = trace_result.first()
|
||||
trace = None
|
||||
if trace_row:
|
||||
trace_model = trace_row[0] if isinstance(trace_row, tuple) else trace_row
|
||||
trace = self._serialize_trace(trace_model)
|
||||
|
||||
return {
|
||||
'message_id': message_id,
|
||||
'found': True,
|
||||
@@ -1090,6 +1267,90 @@ class MonitoringService:
|
||||
'average_duration_ms': int(total_duration / len(llm_rows)) if len(llm_rows) > 0 else 0,
|
||||
},
|
||||
'errors': errors,
|
||||
'trace': trace,
|
||||
}
|
||||
|
||||
def _serialize_trace(self, trace: persistence_monitoring.MonitoringTrace) -> dict:
|
||||
data = self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringTrace, trace)
|
||||
data['attributes'] = _json_loads(data.get('attributes')) or {}
|
||||
return data
|
||||
|
||||
def _serialize_span(self, span: persistence_monitoring.MonitoringSpan) -> dict:
|
||||
data = self.ap.persistence_mgr.serialize_model(persistence_monitoring.MonitoringSpan, span)
|
||||
data['attributes'] = _json_loads(data.get('attributes')) or {}
|
||||
return data
|
||||
|
||||
async def get_traces(
|
||||
self,
|
||||
bot_ids: list[str] | None = None,
|
||||
pipeline_ids: list[str] | None = None,
|
||||
session_ids: list[str] | None = None,
|
||||
statuses: list[str] | None = None,
|
||||
start_time: datetime.datetime | None = None,
|
||||
end_time: datetime.datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get trace headers with filters."""
|
||||
conditions = []
|
||||
if bot_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.bot_id.in_(bot_ids))
|
||||
if pipeline_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.pipeline_id.in_(pipeline_ids))
|
||||
if session_ids:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.session_id.in_(session_ids))
|
||||
if statuses:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.status.in_(statuses))
|
||||
if start_time:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.started_at >= start_time)
|
||||
if end_time:
|
||||
conditions.append(persistence_monitoring.MonitoringTrace.started_at <= end_time)
|
||||
|
||||
count_query = sqlalchemy.select(sqlalchemy.func.count(persistence_monitoring.MonitoringTrace.trace_id))
|
||||
query = sqlalchemy.select(persistence_monitoring.MonitoringTrace)
|
||||
if conditions:
|
||||
clause = sqlalchemy.and_(*conditions)
|
||||
count_query = count_query.where(clause)
|
||||
query = query.where(clause)
|
||||
|
||||
total_result = await self.ap.persistence_mgr.execute_async(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
query = query.order_by(persistence_monitoring.MonitoringTrace.started_at.desc()).limit(limit).offset(offset)
|
||||
result = await self.ap.persistence_mgr.execute_async(query)
|
||||
traces = [
|
||||
self._serialize_trace(row[0] if isinstance(row, tuple) else row)
|
||||
for row in result.all()
|
||||
]
|
||||
return traces, total
|
||||
|
||||
async def get_trace_details(self, trace_id: str) -> dict:
|
||||
"""Get a single trace and all spans in chronological order."""
|
||||
trace_query = sqlalchemy.select(persistence_monitoring.MonitoringTrace).where(
|
||||
persistence_monitoring.MonitoringTrace.trace_id == trace_id
|
||||
)
|
||||
trace_result = await self.ap.persistence_mgr.execute_async(trace_query)
|
||||
trace_row = trace_result.first()
|
||||
if not trace_row:
|
||||
return {'trace_id': trace_id, 'found': False}
|
||||
|
||||
trace = trace_row[0] if isinstance(trace_row, tuple) else trace_row
|
||||
span_query = (
|
||||
sqlalchemy.select(persistence_monitoring.MonitoringSpan)
|
||||
.where(persistence_monitoring.MonitoringSpan.trace_id == trace_id)
|
||||
.order_by(persistence_monitoring.MonitoringSpan.started_at.asc())
|
||||
)
|
||||
span_result = await self.ap.persistence_mgr.execute_async(span_query)
|
||||
spans = [
|
||||
self._serialize_span(row[0] if isinstance(row, tuple) else row)
|
||||
for row in span_result.all()
|
||||
]
|
||||
|
||||
return {
|
||||
'trace_id': trace_id,
|
||||
'found': True,
|
||||
'trace': self._serialize_trace(trace),
|
||||
'spans': spans,
|
||||
}
|
||||
|
||||
# ========== Export Methods ==========
|
||||
|
||||
@@ -146,19 +146,13 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
|
||||
_LB_PIP_CACHE_DIR="{mount_path}/.cache/pip"
|
||||
|
||||
mkdir -p "$_LB_META_DIR" "$_LB_TMP_DIR" "$_LB_PIP_CACHE_DIR"
|
||||
_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"
|
||||
if [ -z "$_LB_SYSTEM_PYTHON" ]; then
|
||||
echo "python3 or python is required to prepare the workspace Python environment" >&2
|
||||
exit 127
|
||||
fi
|
||||
|
||||
export TMPDIR="$_LB_TMP_DIR"
|
||||
export TEMP="$_LB_TMP_DIR"
|
||||
export TMP="$_LB_TMP_DIR"
|
||||
export PIP_CACHE_DIR="$_LB_PIP_CACHE_DIR"
|
||||
|
||||
_lb_python_meta() {{
|
||||
"$_LB_SYSTEM_PYTHON" - <<'PY'
|
||||
python - <<'PY'
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
@@ -207,26 +201,15 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
|
||||
_LB_LOCK_WAIT=0
|
||||
while ! mkdir "$_LB_LOCK_DIR" 2>/dev/null; do
|
||||
if [ "$_LB_LOCK_WAIT" -ge 120 ]; then
|
||||
_LB_LOCK_OWNER="$(cat "$_LB_LOCK_DIR/pid" 2>/dev/null || true)"
|
||||
if [ -n "$_LB_LOCK_OWNER" ] && kill -0 "$_LB_LOCK_OWNER" 2>/dev/null; then
|
||||
echo "Timed out waiting for active Python environment lock: $_LB_LOCK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "Timed out waiting for Python environment lock, clearing stale lock: $_LB_LOCK_DIR" >&2
|
||||
rm -rf "$_LB_LOCK_DIR" 2>/dev/null || true
|
||||
if mkdir "$_LB_LOCK_DIR" 2>/dev/null; then
|
||||
break
|
||||
fi
|
||||
echo "Timed out waiting for Python environment lock: $_LB_LOCK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
sleep 1
|
||||
_LB_LOCK_WAIT=$((_LB_LOCK_WAIT + 1))
|
||||
done
|
||||
printf '%s\\n' "$$" > "$_LB_LOCK_DIR/pid" 2>/dev/null || true
|
||||
|
||||
_lb_cleanup_lock() {{
|
||||
rm -rf "$_LB_LOCK_DIR" >/dev/null 2>&1 || true
|
||||
rmdir "$_LB_LOCK_DIR" >/dev/null 2>&1 || true
|
||||
}}
|
||||
trap _lb_cleanup_lock EXIT INT TERM
|
||||
|
||||
@@ -242,7 +225,7 @@ def wrap_python_command_with_env(command: str, *, mount_path: str = '/workspace'
|
||||
|
||||
if [ "$_LB_NEEDS_BOOTSTRAP" -eq 1 ]; then
|
||||
rm -rf "$_LB_VENV_DIR"
|
||||
"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"
|
||||
python -m venv "$_LB_VENV_DIR"
|
||||
. "$_LB_VENV_DIR/bin/activate"
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
if [ -f "{mount_path}/requirements.txt" ]; then
|
||||
|
||||
@@ -3,6 +3,49 @@ import sqlalchemy
|
||||
from .base import Base
|
||||
|
||||
|
||||
class MonitoringTrace(Base):
|
||||
"""End-to-end monitoring trace records"""
|
||||
|
||||
__tablename__ = 'monitoring_traces'
|
||||
|
||||
trace_id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
ended_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
|
||||
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) # milliseconds
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True) # running, success, error
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
bot_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
query_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
attributes = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
|
||||
|
||||
class MonitoringSpan(Base):
|
||||
"""Trace span records for pipeline, RAG, model, plugin and tool operations"""
|
||||
|
||||
__tablename__ = 'monitoring_spans'
|
||||
|
||||
span_id = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
trace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
parent_span_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
kind = sqlalchemy.Column(sqlalchemy.String(80), nullable=False, index=True)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
|
||||
started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, index=True)
|
||||
ended_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
duration = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) # milliseconds
|
||||
message_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
session_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
pipeline_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
attributes = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
error_message = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
|
||||
|
||||
class MonitoringMessage(Base):
|
||||
"""Monitoring message records"""
|
||||
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
"""add monitoring traces and spans
|
||||
|
||||
Revision ID: 0006_monitoring_traces
|
||||
Revises: 0005_add_llm_context_length
|
||||
Create Date: 2026-06-16
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '0006_monitoring_traces'
|
||||
down_revision = '0005_add_llm_context_length'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
tables = set(inspector.get_table_names())
|
||||
|
||||
if 'monitoring_traces' not in tables:
|
||||
op.create_table(
|
||||
'monitoring_traces',
|
||||
sa.Column('trace_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('ended_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('duration', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('bot_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('bot_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('pipeline_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('pipeline_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('session_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('message_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('query_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('attributes', sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('trace_id'),
|
||||
)
|
||||
op.create_index('ix_monitoring_traces_started_at', 'monitoring_traces', ['started_at'])
|
||||
op.create_index('ix_monitoring_traces_ended_at', 'monitoring_traces', ['ended_at'])
|
||||
op.create_index('ix_monitoring_traces_status', 'monitoring_traces', ['status'])
|
||||
op.create_index('ix_monitoring_traces_bot_id', 'monitoring_traces', ['bot_id'])
|
||||
op.create_index('ix_monitoring_traces_pipeline_id', 'monitoring_traces', ['pipeline_id'])
|
||||
op.create_index('ix_monitoring_traces_session_id', 'monitoring_traces', ['session_id'])
|
||||
op.create_index('ix_monitoring_traces_message_id', 'monitoring_traces', ['message_id'])
|
||||
op.create_index('ix_monitoring_traces_query_id', 'monitoring_traces', ['query_id'])
|
||||
|
||||
if 'monitoring_spans' not in tables:
|
||||
op.create_table(
|
||||
'monitoring_spans',
|
||||
sa.Column('span_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('trace_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('parent_span_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('kind', sa.String(length=80), nullable=False),
|
||||
sa.Column('status', sa.String(length=50), nullable=False),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('ended_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('duration', sa.Integer(), nullable=True),
|
||||
sa.Column('message_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('session_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('bot_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('pipeline_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('attributes', sa.Text(), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('span_id'),
|
||||
)
|
||||
op.create_index('ix_monitoring_spans_trace_id', 'monitoring_spans', ['trace_id'])
|
||||
op.create_index('ix_monitoring_spans_parent_span_id', 'monitoring_spans', ['parent_span_id'])
|
||||
op.create_index('ix_monitoring_spans_kind', 'monitoring_spans', ['kind'])
|
||||
op.create_index('ix_monitoring_spans_status', 'monitoring_spans', ['status'])
|
||||
op.create_index('ix_monitoring_spans_started_at', 'monitoring_spans', ['started_at'])
|
||||
op.create_index('ix_monitoring_spans_message_id', 'monitoring_spans', ['message_id'])
|
||||
op.create_index('ix_monitoring_spans_session_id', 'monitoring_spans', ['session_id'])
|
||||
op.create_index('ix_monitoring_spans_bot_id', 'monitoring_spans', ['bot_id'])
|
||||
op.create_index('ix_monitoring_spans_pipeline_id', 'monitoring_spans', ['pipeline_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
tables = set(inspector.get_table_names())
|
||||
if 'monitoring_spans' in tables:
|
||||
op.drop_table('monitoring_spans')
|
||||
if 'monitoring_traces' in tables:
|
||||
op.drop_table('monitoring_traces')
|
||||
@@ -2,6 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
import time
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
@@ -79,6 +82,19 @@ class RuntimePipeline:
|
||||
enable_all_plugins: bool
|
||||
"""是否启用所有插件"""
|
||||
|
||||
@staticmethod
|
||||
def _new_span_id() -> str:
|
||||
return f'span-{uuid.uuid4().hex[:16]}'
|
||||
|
||||
@staticmethod
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
@staticmethod
|
||||
def _query_session_id(query: pipeline_query.Query) -> str:
|
||||
launcher_type = query.launcher_type.value if hasattr(query.launcher_type, 'value') else str(query.launcher_type)
|
||||
return f'{launcher_type}_{query.launcher_id}'
|
||||
|
||||
enable_all_mcp_servers: bool
|
||||
"""是否启用所有MCP服务器"""
|
||||
|
||||
@@ -234,44 +250,92 @@ class RuntimePipeline:
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
|
||||
span_started_at = self._utc_now()
|
||||
span_started = time.perf_counter()
|
||||
span_status = 'success'
|
||||
span_error = None
|
||||
span_result_type = None
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
try:
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
|
||||
)
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')
|
||||
|
||||
async for sub_result in result:
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
span_result_type = str(result.result_type.value if hasattr(result.result_type, 'value') else result.result_type)
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
|
||||
f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
|
||||
)
|
||||
await self._check_output(query, sub_result)
|
||||
await self._check_output(query, result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
span_result_type = 'generator'
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')
|
||||
|
||||
async for sub_result in result:
|
||||
span_result_type = str(
|
||||
sub_result.result_type.value
|
||||
if hasattr(sub_result.result_type, 'value')
|
||||
else sub_result.result_type
|
||||
)
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
|
||||
)
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
except Exception as e:
|
||||
span_status = 'error'
|
||||
span_error = str(e)
|
||||
raise
|
||||
finally:
|
||||
trace_id = (query.variables or {}).get('_monitoring_trace_id')
|
||||
root_span_id = (query.variables or {}).get('_monitoring_root_span_id')
|
||||
if trace_id:
|
||||
try:
|
||||
await self.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=root_span_id,
|
||||
name=stage_container.inst_name,
|
||||
kind='pipeline.stage',
|
||||
status=span_status,
|
||||
started_at=span_started_at,
|
||||
duration=int((time.perf_counter() - span_started) * 1000),
|
||||
message_id=(query.variables or {}).get('_monitoring_message_id'),
|
||||
session_id=self._query_session_id(query),
|
||||
bot_id=query.bot_uuid,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
attributes={
|
||||
'stage_class': stage_container.inst.__class__.__name__,
|
||||
'result_type': span_result_type,
|
||||
'query_id': query.query_id,
|
||||
},
|
||||
error_message=span_error,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.ap.logger.error(f'Failed to record stage span: {monitor_err}')
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: pipeline_query.Query):
|
||||
"""处理请求"""
|
||||
trace_started_at = self._utc_now()
|
||||
trace_started = time.perf_counter()
|
||||
root_span_id = self._new_span_id()
|
||||
trace_id = None
|
||||
trace_status = 'success'
|
||||
# Get monitoring metadata
|
||||
bot_name = query.variables.get('_monitoring_bot_name', 'Unknown')
|
||||
pipeline_name = query.variables.get('_monitoring_pipeline_name', 'Unknown')
|
||||
@@ -303,6 +367,28 @@ class RuntimePipeline:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record query start: {e}')
|
||||
|
||||
try:
|
||||
trace_id = await self.ap.monitoring_service.start_trace(
|
||||
name='LangBot query',
|
||||
bot_id=query.bot_uuid or 'unknown',
|
||||
bot_name=bot_name,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
pipeline_name=pipeline_name,
|
||||
session_id=self._query_session_id(query),
|
||||
message_id=message_id or None,
|
||||
query_id=query.query_id,
|
||||
attributes={
|
||||
'launcher_type': query.launcher_type.value
|
||||
if hasattr(query.launcher_type, 'value')
|
||||
else str(query.launcher_type),
|
||||
'runner_name': runner_name,
|
||||
},
|
||||
)
|
||||
query.variables['_monitoring_trace_id'] = trace_id
|
||||
query.variables['_monitoring_root_span_id'] = root_span_id
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to start query trace: {e}')
|
||||
|
||||
try:
|
||||
# Get bound plugins for this pipeline
|
||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
@@ -361,6 +447,7 @@ class RuntimePipeline:
|
||||
self.ap.logger.error(f'Failed to record query response: {e}')
|
||||
|
||||
except Exception as e:
|
||||
trace_status = 'error'
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
@@ -383,6 +470,35 @@ class RuntimePipeline:
|
||||
self.ap.logger.error(f'Failed to record query error: {me}')
|
||||
|
||||
finally:
|
||||
if trace_id:
|
||||
try:
|
||||
duration_ms = int((time.perf_counter() - trace_started) * 1000)
|
||||
await self.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
span_id=root_span_id,
|
||||
name='LangBot query',
|
||||
kind='pipeline.query',
|
||||
status=trace_status,
|
||||
started_at=trace_started_at,
|
||||
duration=duration_ms,
|
||||
message_id=message_id or None,
|
||||
session_id=self._query_session_id(query),
|
||||
bot_id=query.bot_uuid,
|
||||
pipeline_id=self.pipeline_entity.uuid,
|
||||
attributes={
|
||||
'query_id': query.query_id,
|
||||
'pipeline_name': pipeline_name,
|
||||
'runner_name': runner_name,
|
||||
},
|
||||
)
|
||||
await self.ap.monitoring_service.finish_trace(
|
||||
trace_id=trace_id,
|
||||
status=trace_status,
|
||||
duration=duration_ms,
|
||||
message_id=message_id or None,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.ap.logger.error(f'Failed to finish query trace: {monitor_err}')
|
||||
self.ap.logger.debug(f'Query {query.query_id} processed')
|
||||
del self.ap.query_pool.cached_queries[query.query_id]
|
||||
|
||||
|
||||
@@ -711,8 +711,19 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
|
||||
endpoint: str,
|
||||
method: str,
|
||||
body: Any = None,
|
||||
caller: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await self.handler.handle_page_api(plugin_author, plugin_name, page_id, endpoint, method, body)
|
||||
return await self.handler.handle_page_api(
|
||||
plugin_author,
|
||||
plugin_name,
|
||||
page_id,
|
||||
endpoint,
|
||||
method,
|
||||
body,
|
||||
caller,
|
||||
headers or {},
|
||||
)
|
||||
|
||||
async def get_debug_info(self) -> dict[str, Any]:
|
||||
"""Get debug information including debug key and WS URL"""
|
||||
|
||||
@@ -755,6 +755,19 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'session_name': session_name,
|
||||
'bot_uuid': query.bot_uuid or '',
|
||||
'sender_id': str(query.sender_id),
|
||||
'_trace_context': {
|
||||
'trace_id': query.variables.get('_monitoring_trace_id') if query.variables else None,
|
||||
'parent_span_id': query.variables.get('_monitoring_root_span_id') if query.variables else None,
|
||||
'message_id': query.variables.get('_monitoring_message_id') if query.variables else None,
|
||||
'query_id': query.query_id,
|
||||
'session_id': session_name,
|
||||
'bot_id': query.bot_uuid or '',
|
||||
'pipeline_id': query.pipeline_uuid or '',
|
||||
'knowledge_base_id': kb_id,
|
||||
'attributes': {
|
||||
'source': 'plugin-api',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
results = [entry.model_dump(mode='json') for entry in entries]
|
||||
@@ -1011,6 +1024,8 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
endpoint: str,
|
||||
method: str,
|
||||
body: Any = None,
|
||||
caller: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Forward a page API call to the plugin via runtime."""
|
||||
result = await self.call_action(
|
||||
@@ -1022,6 +1037,8 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'endpoint': endpoint,
|
||||
'method': method,
|
||||
'body': body,
|
||||
'caller': caller,
|
||||
'headers': headers or {},
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
import time
|
||||
import datetime
|
||||
|
||||
from ...core import app
|
||||
from ...entity.persistence import model as persistence_model
|
||||
@@ -16,6 +17,15 @@ LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
|
||||
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
def _query_session_id(query: pipeline_query.Query) -> str:
|
||||
launcher_type = query.launcher_type.value if hasattr(query.launcher_type, 'value') else str(query.launcher_type)
|
||||
return f'{launcher_type}_{query.launcher_id}'
|
||||
|
||||
|
||||
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
|
||||
"""Store the latest provider usage on the query for upstream action handlers."""
|
||||
if query is None or not usage_info:
|
||||
@@ -59,6 +69,7 @@ class RuntimeProvider:
|
||||
"""Bridge method for invoking LLM with monitoring"""
|
||||
# Start timing for monitoring
|
||||
start_time = time.time()
|
||||
span_started_at = _utc_now()
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
status = 'success'
|
||||
@@ -125,6 +136,30 @@ class RuntimeProvider:
|
||||
error_message=error_message,
|
||||
message_id=message_id,
|
||||
)
|
||||
trace_id = query.variables.get('_monitoring_trace_id') if query.variables else None
|
||||
parent_span_id = query.variables.get('_monitoring_root_span_id') if query.variables else None
|
||||
if trace_id:
|
||||
await self.requester.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
name=f'LLM {model.model_entity.name}',
|
||||
kind='model.llm',
|
||||
status=status,
|
||||
started_at=span_started_at,
|
||||
duration=duration_ms,
|
||||
message_id=message_id,
|
||||
session_id=_query_session_id(query),
|
||||
bot_id=query.bot_uuid,
|
||||
pipeline_id=query.pipeline_uuid,
|
||||
attributes={
|
||||
'model_name': model.model_entity.name,
|
||||
'input_tokens': input_tokens,
|
||||
'output_tokens': output_tokens,
|
||||
'total_tokens': input_tokens + output_tokens,
|
||||
'stream': False,
|
||||
},
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM call: {monitor_err}')
|
||||
|
||||
@@ -140,6 +175,7 @@ class RuntimeProvider:
|
||||
"""Bridge method for invoking LLM stream with monitoring"""
|
||||
# Start timing for monitoring
|
||||
start_time = time.time()
|
||||
span_started_at = _utc_now()
|
||||
status = 'success'
|
||||
error_message = None
|
||||
input_tokens = 0
|
||||
@@ -204,6 +240,30 @@ class RuntimeProvider:
|
||||
error_message=error_message,
|
||||
message_id=message_id,
|
||||
)
|
||||
trace_id = query.variables.get('_monitoring_trace_id') if query.variables else None
|
||||
parent_span_id = query.variables.get('_monitoring_root_span_id') if query.variables else None
|
||||
if trace_id:
|
||||
await self.requester.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
name=f'LLM stream {model.model_entity.name}',
|
||||
kind='model.llm',
|
||||
status=status,
|
||||
started_at=span_started_at,
|
||||
duration=duration_ms,
|
||||
message_id=message_id,
|
||||
session_id=_query_session_id(query),
|
||||
bot_id=query.bot_uuid,
|
||||
pipeline_id=query.pipeline_uuid,
|
||||
attributes={
|
||||
'model_name': model.model_entity.name,
|
||||
'input_tokens': input_tokens,
|
||||
'output_tokens': output_tokens,
|
||||
'total_tokens': input_tokens + output_tokens,
|
||||
'stream': True,
|
||||
},
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as monitor_err:
|
||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record LLM stream call: {monitor_err}')
|
||||
|
||||
|
||||
@@ -268,6 +268,19 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
'bot_uuid': query.bot_uuid or '',
|
||||
'sender_id': str(query.sender_id),
|
||||
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'_trace_context': {
|
||||
'trace_id': query.variables.get('_monitoring_trace_id') if query.variables else None,
|
||||
'parent_span_id': query.variables.get('_monitoring_root_span_id') if query.variables else None,
|
||||
'message_id': query.variables.get('_monitoring_message_id') if query.variables else None,
|
||||
'query_id': query.query_id,
|
||||
'session_id': f'{query.launcher_type.value}_{query.launcher_id}',
|
||||
'bot_id': query.bot_uuid or '',
|
||||
'pipeline_id': query.pipeline_uuid or '',
|
||||
'knowledge_base_id': kb_uuid,
|
||||
'attributes': {
|
||||
'source': 'local-agent',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
async def is_box_backend_available(ap: Any) -> bool:
|
||||
"""Return whether the configured Box backend is ready for tool execution."""
|
||||
box_service = getattr(ap, 'box_service', None)
|
||||
if box_service is None:
|
||||
return False
|
||||
if not getattr(box_service, 'available', False):
|
||||
return False
|
||||
try:
|
||||
status = await box_service.get_status()
|
||||
backend_info = status.get('backend', {})
|
||||
return bool(backend_info.get('available', False))
|
||||
except Exception:
|
||||
return False
|
||||
@@ -5,8 +5,6 @@ import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import shlex
|
||||
import threading
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pydantic
|
||||
@@ -20,26 +18,12 @@ from ....box.workspace import (
|
||||
rewrite_mounted_path,
|
||||
rewrite_venv_command,
|
||||
unwrap_venv_path,
|
||||
wrap_python_command_with_env,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .mcp import RuntimeMCPSession
|
||||
|
||||
|
||||
_WORKSPACE_COPY_LOCKS: dict[str, threading.Lock] = {}
|
||||
_WORKSPACE_COPY_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _workspace_copy_lock(path: str) -> threading.Lock:
|
||||
with _WORKSPACE_COPY_LOCKS_GUARD:
|
||||
lock = _WORKSPACE_COPY_LOCKS.get(path)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_WORKSPACE_COPY_LOCKS[path] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class MCPSessionErrorPhase(enum.Enum):
|
||||
"""Which phase of the MCP lifecycle failed."""
|
||||
|
||||
@@ -65,7 +49,7 @@ class MCPServerBoxConfig(pydantic.BaseModel):
|
||||
host_path: str | None = None
|
||||
host_path_mode: str = 'ro' # MCP servers default to read-write mount only when explicitly requested
|
||||
env: dict[str, str] = pydantic.Field(default_factory=dict)
|
||||
startup_timeout_sec: int = 300 # First Docker bootstrap may need to build a venv and install MCP deps.
|
||||
startup_timeout_sec: int = 120 # Longer default to allow dependency bootstrap
|
||||
cpus: float | None = None
|
||||
memory_mb: int | None = None
|
||||
pids_limit: int | None = None
|
||||
@@ -144,7 +128,6 @@ class BoxStdioSessionRuntime:
|
||||
workspace = self._build_workspace(host_path=None)
|
||||
host_path = self.resolve_host_path()
|
||||
process_cwd = '/workspace'
|
||||
install_cmd: str | None = None
|
||||
|
||||
try:
|
||||
await workspace.create_session()
|
||||
@@ -185,8 +168,6 @@ class BoxStdioSessionRuntime:
|
||||
env=self.server_config.get('env', {}),
|
||||
cwd=process_cwd,
|
||||
)
|
||||
if install_cmd:
|
||||
payload = self._wrap_process_payload_with_python_env(payload, process_cwd)
|
||||
payload['process_id'] = self.process_id
|
||||
await workspace.box_service.start_managed_process(workspace.session_id, payload)
|
||||
except Exception:
|
||||
@@ -272,42 +253,14 @@ class BoxStdioSessionRuntime:
|
||||
|
||||
@staticmethod
|
||||
def _copy_workspace_tree(source_path: str, process_host_root: str, process_host_workspace: str) -> None:
|
||||
# Docker-backed bootstrap writes root-owned runtime directories such as
|
||||
# .venv/.tmp into the staged workspace. The host process may not be able
|
||||
# to delete them, so refresh source files in place and preserve runtime
|
||||
# directories instead of rmtree'ing the whole staging root.
|
||||
with _workspace_copy_lock(process_host_root):
|
||||
preserved_names = {'.venv', 'venv', 'env', '.cache', '.tmp', '.langbot'}
|
||||
os.makedirs(process_host_workspace, exist_ok=True)
|
||||
for name in os.listdir(process_host_workspace):
|
||||
if name in preserved_names:
|
||||
continue
|
||||
path = os.path.join(process_host_workspace, name)
|
||||
if os.path.isdir(path) and not os.path.islink(path):
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
else:
|
||||
# The entry may disappear between listdir and unlink if cleanup races us.
|
||||
with suppress(FileNotFoundError):
|
||||
os.unlink(path)
|
||||
shutil.copytree(
|
||||
source_path,
|
||||
process_host_workspace,
|
||||
symlinks=True,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns(
|
||||
'.git',
|
||||
'__pycache__',
|
||||
'.pytest_cache',
|
||||
'.mypy_cache',
|
||||
'.ruff_cache',
|
||||
'.venv',
|
||||
'venv',
|
||||
'env',
|
||||
'.cache',
|
||||
'.tmp',
|
||||
'.langbot',
|
||||
),
|
||||
)
|
||||
shutil.rmtree(process_host_root, ignore_errors=True)
|
||||
os.makedirs(process_host_root, exist_ok=True)
|
||||
shutil.copytree(
|
||||
source_path,
|
||||
process_host_workspace,
|
||||
symlinks=True,
|
||||
ignore=shutil.ignore_patterns('.git', '__pycache__', '.pytest_cache', '.mypy_cache', '.ruff_cache'),
|
||||
)
|
||||
|
||||
async def _cleanup_staged_workspace(self) -> None:
|
||||
if not self.resolve_host_path():
|
||||
@@ -390,25 +343,23 @@ class BoxStdioSessionRuntime:
|
||||
@staticmethod
|
||||
def detect_install_command(host_path: str, workspace_path: str = '/workspace') -> str | None:
|
||||
workspace_kind = classify_python_workspace(host_path)
|
||||
if workspace_kind in {'package', 'requirements'}:
|
||||
return wrap_python_command_with_env('python -c "pass"', mount_path=workspace_path).rstrip()
|
||||
quoted_workspace_path = shlex.quote(workspace_path)
|
||||
if workspace_kind == 'package':
|
||||
return (
|
||||
'mkdir -p /opt/_lb_src'
|
||||
f' && tar -C {quoted_workspace_path}'
|
||||
' --exclude=.venv --exclude=.git --exclude=__pycache__'
|
||||
' --exclude=node_modules --exclude=.tox --exclude=.nox'
|
||||
' --exclude="*.egg-info" --exclude=.uv-cache'
|
||||
' -cf - .'
|
||||
' | tar -C /opt/_lb_src -xf -'
|
||||
' && pip install --no-cache-dir /opt/_lb_src'
|
||||
' && rm -rf /opt/_lb_src'
|
||||
)
|
||||
if workspace_kind == 'requirements':
|
||||
return f'pip install --no-cache-dir -r {quoted_workspace_path}/requirements.txt'
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _wrap_process_payload_with_python_env(payload: dict[str, Any], workspace_path: str) -> dict[str, Any]:
|
||||
"""Start a prepared Python workspace without writing bootstrap output to MCP stdio."""
|
||||
workspace_root = workspace_path.rstrip('/') or '/workspace'
|
||||
venv_dir = f'{workspace_root}/.venv'
|
||||
venv_bin = f'{venv_dir}/bin'
|
||||
command = ' '.join([shlex.quote(payload['command']), *[shlex.quote(arg) for arg in payload.get('args', [])]])
|
||||
wrapped = dict(payload)
|
||||
wrapped['command'] = 'sh'
|
||||
wrapped['args'] = [
|
||||
'-lc',
|
||||
(f'export VIRTUAL_ENV={shlex.quote(venv_dir)}; export PATH={shlex.quote(venv_bin)}:$PATH; exec {command}'),
|
||||
]
|
||||
return wrapped
|
||||
|
||||
def build_box_session_payload(self, session_id: str, host_path: str | None = None) -> dict[str, Any]:
|
||||
workspace = self._build_workspace()
|
||||
workspace.session_id = session_id
|
||||
|
||||
@@ -8,7 +8,6 @@ from langbot_plugin.api.entities.events import pipeline_query
|
||||
|
||||
from .. import loader
|
||||
from ..errors import ToolNotFoundError
|
||||
from .availability import is_box_backend_available
|
||||
from . import skill as skill_loader
|
||||
|
||||
EXEC_TOOL_NAME = 'exec'
|
||||
@@ -23,15 +22,6 @@ _ALL_TOOL_NAMES = {EXEC_TOOL_NAME, READ_TOOL_NAME, WRITE_TOOL_NAME, EDIT_TOOL_NA
|
||||
# Skip these dirs during grep walk to avoid noise
|
||||
_SKIP_DIRS = {'.git', 'node_modules', '__pycache__', '.venv', 'venv', '.tox', 'dist', 'build'}
|
||||
|
||||
_DEFAULT_READ_MAX_LINES = 2000
|
||||
_MAX_READ_MAX_LINES = 10000
|
||||
_DEFAULT_TOOL_RESULT_MAX_BYTES = 50 * 1024
|
||||
_BOX_FILE_SCRIPT_MAX_BYTES = 2048
|
||||
_GLOB_MAX_MATCHES = 100
|
||||
_GREP_MAX_MATCHES = 200
|
||||
_GREP_MAX_FILES = 5000
|
||||
_GREP_MAX_LINE_CHARS = 500
|
||||
|
||||
|
||||
class NativeToolLoader(loader.ToolLoader):
|
||||
def __init__(self, ap):
|
||||
@@ -53,7 +43,18 @@ class NativeToolLoader(loader.ToolLoader):
|
||||
|
||||
async def _check_backend_available(self) -> bool:
|
||||
"""Check if the box backend is truly available (not just the runtime)."""
|
||||
return await is_box_backend_available(self.ap)
|
||||
box_service = getattr(self.ap, 'box_service', None)
|
||||
if box_service is None:
|
||||
return False
|
||||
if not getattr(box_service, 'available', False):
|
||||
return False
|
||||
# Check if backend is truly available via get_status
|
||||
try:
|
||||
status = await box_service.get_status()
|
||||
backend_info = status.get('backend', {})
|
||||
return backend_info.get('available', False)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
|
||||
if not self._is_sandbox_available():
|
||||
@@ -138,7 +139,6 @@ class NativeToolLoader(loader.ToolLoader):
|
||||
# via execute_tool. Skills are mounted at /workspace/.skills/{name}/
|
||||
# via extra_mounts built by BoxService.
|
||||
result = await self.ap.box_service.execute_tool(parameters, query)
|
||||
result = self._normalize_exec_result(result)
|
||||
|
||||
if selected_skill is not None:
|
||||
self._refresh_skill_from_disk(selected_skill)
|
||||
@@ -227,65 +227,19 @@ class NativeToolLoader(loader.ToolLoader):
|
||||
except Exception:
|
||||
return {'ok': False, 'error': stdout or 'Box file operation returned no result'}
|
||||
|
||||
async def _read_workspace_via_box(self, path: str, parameters: dict, query: pipeline_query.Query) -> dict:
|
||||
offset = self._positive_int(parameters.get('offset'), default=1)
|
||||
max_lines = self._positive_int(
|
||||
parameters.get('limit'),
|
||||
default=_DEFAULT_READ_MAX_LINES,
|
||||
max_value=_MAX_READ_MAX_LINES,
|
||||
)
|
||||
# Box file fallback returns through exec stdout, which is already capped
|
||||
# by BoxService. Keep this payload small enough to remain valid JSON.
|
||||
max_bytes = min(
|
||||
self._positive_int(parameters.get('max_bytes'), default=_DEFAULT_TOOL_RESULT_MAX_BYTES),
|
||||
_BOX_FILE_SCRIPT_MAX_BYTES,
|
||||
)
|
||||
async def _read_workspace_via_box(self, path: str, query: pipeline_query.Query) -> dict:
|
||||
script = f"""
|
||||
import json, os
|
||||
path = {json.dumps(path)}
|
||||
offset = {offset}
|
||||
max_lines = {max_lines}
|
||||
max_bytes = {max_bytes}
|
||||
if not path.startswith('/workspace'):
|
||||
print(json.dumps({{'ok': False, 'error': 'Path must be under /workspace.'}}))
|
||||
elif not os.path.exists(path):
|
||||
print(json.dumps({{'ok': False, 'error': f'File not found: {{path}}'}}))
|
||||
elif os.path.isdir(path):
|
||||
entries = sorted(os.listdir(path))
|
||||
content = '\\n'.join(entries)
|
||||
print(json.dumps({{'ok': True, 'content': content, 'is_directory': True, 'total': len(entries), 'truncated': False}}))
|
||||
print(json.dumps({{'ok': True, 'content': '\\n'.join(sorted(os.listdir(path))), 'is_directory': True}}))
|
||||
else:
|
||||
lines = []
|
||||
output_bytes = 0
|
||||
end_line = offset - 1
|
||||
truncated = False
|
||||
next_offset = None
|
||||
with open(path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
for line_number, line in enumerate(f, 1):
|
||||
if line_number < offset:
|
||||
continue
|
||||
if len(lines) >= max_lines:
|
||||
truncated = True
|
||||
next_offset = line_number
|
||||
break
|
||||
line_bytes = len(line.encode('utf-8'))
|
||||
if output_bytes + line_bytes > max_bytes:
|
||||
truncated = True
|
||||
next_offset = line_number
|
||||
break
|
||||
lines.append(line.rstrip('\\n'))
|
||||
output_bytes += line_bytes
|
||||
end_line = line_number
|
||||
print(json.dumps({{
|
||||
'ok': True,
|
||||
'content': '\\n'.join(lines),
|
||||
'truncated': truncated,
|
||||
'start_line': offset,
|
||||
'end_line': end_line,
|
||||
'next_offset': next_offset,
|
||||
'max_lines': max_lines,
|
||||
'max_bytes': max_bytes,
|
||||
}}))
|
||||
print(json.dumps({{'ok': True, 'content': f.read()}}))
|
||||
""".strip()
|
||||
return await self._run_workspace_file_script(script, query)
|
||||
|
||||
@@ -353,27 +307,12 @@ else:
|
||||
if not any(part in skip_dirs for part in item.parts)
|
||||
]
|
||||
hits.sort(key=lambda item: item.stat().st_mtime if item.exists() else 0, reverse=True)
|
||||
shown = hits[:{_GLOB_MAX_MATCHES}]
|
||||
shown = hits[:100]
|
||||
matches = []
|
||||
output_bytes = 0
|
||||
truncated_by_bytes = False
|
||||
for item in shown:
|
||||
rel = os.path.relpath(str(item), path)
|
||||
sandbox_path = os.path.join(path, rel).replace(os.sep, '/')
|
||||
entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if matches else 0)
|
||||
if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}:
|
||||
truncated_by_bytes = True
|
||||
break
|
||||
matches.append(sandbox_path)
|
||||
output_bytes += entry_bytes
|
||||
print(json.dumps({{
|
||||
'ok': True,
|
||||
'matches': matches,
|
||||
'preview': '\\n'.join(matches),
|
||||
'total': len(hits),
|
||||
'truncated': len(hits) > len(matches) or truncated_by_bytes,
|
||||
'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if len(hits) > len(matches) else None),
|
||||
}}))
|
||||
matches.append(os.path.join(path, rel).replace(os.sep, '/'))
|
||||
print(json.dumps({{'ok': True, 'matches': matches, 'total': len(hits), 'truncated': len(hits) > 100}}))
|
||||
""".strip()
|
||||
return await self._run_workspace_file_script(script, query)
|
||||
|
||||
@@ -411,54 +350,29 @@ else:
|
||||
continue
|
||||
if item.is_file():
|
||||
files.append(item)
|
||||
if len(files) >= {_GREP_MAX_FILES}:
|
||||
if len(files) >= 5000:
|
||||
break
|
||||
|
||||
matches = []
|
||||
output_bytes = 0
|
||||
truncated_by = None
|
||||
for fp in files:
|
||||
try:
|
||||
handle = fp.open('r', encoding='utf-8', errors='ignore')
|
||||
text = fp.read_text(errors='ignore')
|
||||
except OSError:
|
||||
continue
|
||||
with handle:
|
||||
for lineno, line in enumerate(handle, 1):
|
||||
if regex.search(line):
|
||||
if base.is_file():
|
||||
file_path = path
|
||||
else:
|
||||
rel = os.path.relpath(str(fp), path)
|
||||
file_path = os.path.join(path, rel).replace(os.sep, '/')
|
||||
content = line.rstrip()
|
||||
line_truncated = False
|
||||
if len(content) > {_GREP_MAX_LINE_CHARS}:
|
||||
content = content[:{_GREP_MAX_LINE_CHARS}] + '... [truncated]'
|
||||
line_truncated = True
|
||||
entry = {{'file': file_path, 'line': lineno, 'content': content}}
|
||||
entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1
|
||||
if output_bytes + entry_bytes > {_DEFAULT_TOOL_RESULT_MAX_BYTES}:
|
||||
truncated_by = 'bytes'
|
||||
break
|
||||
if line_truncated and truncated_by is None:
|
||||
truncated_by = 'line'
|
||||
matches.append(entry)
|
||||
output_bytes += entry_bytes
|
||||
if len(matches) >= {_GREP_MAX_MATCHES}:
|
||||
truncated_by = truncated_by or 'matches'
|
||||
break
|
||||
if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}:
|
||||
break
|
||||
if truncated_by == 'bytes' or len(matches) >= {_GREP_MAX_MATCHES}:
|
||||
for lineno, line in enumerate(text.splitlines(), 1):
|
||||
if regex.search(line):
|
||||
if base.is_file():
|
||||
file_path = path
|
||||
else:
|
||||
rel = os.path.relpath(str(fp), path)
|
||||
file_path = os.path.join(path, rel).replace(os.sep, '/')
|
||||
matches.append({{'file': file_path, 'line': lineno, 'content': line.rstrip()}})
|
||||
if len(matches) >= 200:
|
||||
break
|
||||
if len(matches) >= 200:
|
||||
break
|
||||
|
||||
print(json.dumps({{
|
||||
'ok': True,
|
||||
'matches': matches,
|
||||
'total': len(matches),
|
||||
'truncated': truncated_by is not None,
|
||||
'truncated_by': truncated_by,
|
||||
}}))
|
||||
print(json.dumps({{'ok': True, 'matches': matches, 'total': len(matches), 'truncated': len(matches) >= 200}}))
|
||||
""".strip()
|
||||
return await self._run_workspace_file_script(script, query)
|
||||
|
||||
@@ -473,20 +387,14 @@ else:
|
||||
)
|
||||
if skill_request is not None and hasattr(self.ap.box_service, 'read_skill_file'):
|
||||
selected_skill, relative = skill_request
|
||||
host_path = self._resolve_skill_host_path(selected_skill, relative)
|
||||
if host_path and os.path.exists(host_path):
|
||||
if os.path.isdir(host_path):
|
||||
return self._build_directory_result(os.listdir(host_path))
|
||||
return self._read_text_file_preview(host_path, parameters)
|
||||
|
||||
try:
|
||||
result = await self.ap.box_service.read_skill_file(selected_skill['name'], relative)
|
||||
return self._build_read_result_from_text(str(result.get('content', '')), parameters)
|
||||
return {'ok': True, 'content': result.get('content', '')}
|
||||
except Exception:
|
||||
try:
|
||||
result = await self.ap.box_service.list_skill_files(selected_skill['name'], relative)
|
||||
entries = [entry['name'] for entry in result.get('entries', [])]
|
||||
return self._build_directory_result(entries)
|
||||
return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True}
|
||||
except Exception as exc:
|
||||
return {'ok': False, 'error': str(exc)}
|
||||
|
||||
@@ -497,13 +405,15 @@ else:
|
||||
include_activated=True,
|
||||
)
|
||||
if self._should_use_box_workspace_files(selected_skill):
|
||||
return await self._read_workspace_via_box(path, parameters, query)
|
||||
return await self._read_workspace_via_box(path, query)
|
||||
if not os.path.exists(host_path):
|
||||
return {'ok': False, 'error': f'File not found: {path}'}
|
||||
if os.path.isdir(host_path):
|
||||
entries = os.listdir(host_path)
|
||||
return self._build_directory_result(entries)
|
||||
return self._read_text_file_preview(host_path, parameters)
|
||||
return {'ok': True, 'content': '\n'.join(sorted(entries)), 'is_directory': True}
|
||||
with open(host_path, 'r', errors='replace') as f:
|
||||
content = f.read()
|
||||
return {'ok': True, 'content': content}
|
||||
|
||||
async def _invoke_write(self, parameters: dict, query: pipeline_query.Query) -> dict:
|
||||
path = parameters['path']
|
||||
@@ -674,28 +584,6 @@ else:
|
||||
'type': 'string',
|
||||
'description': 'Absolute path to the file (must be under /workspace).',
|
||||
},
|
||||
'offset': {
|
||||
'type': 'integer',
|
||||
'description': '1-indexed line number to start reading from. Defaults to 1.',
|
||||
'default': 1,
|
||||
'minimum': 1,
|
||||
},
|
||||
'limit': {
|
||||
'type': 'integer',
|
||||
'description': f'Maximum number of lines to return. Defaults to {_DEFAULT_READ_MAX_LINES}.',
|
||||
'default': _DEFAULT_READ_MAX_LINES,
|
||||
'minimum': 1,
|
||||
'maximum': _MAX_READ_MAX_LINES,
|
||||
},
|
||||
'max_bytes': {
|
||||
'type': 'integer',
|
||||
'description': (
|
||||
f'Maximum bytes of file content to return. Defaults to {_DEFAULT_TOOL_RESULT_MAX_BYTES}.'
|
||||
),
|
||||
'default': _DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
'minimum': 1,
|
||||
'maximum': _DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
},
|
||||
},
|
||||
'required': ['path'],
|
||||
'additionalProperties': False,
|
||||
@@ -852,30 +740,22 @@ else:
|
||||
hits.sort(key=lambda p: p.stat().st_mtime if p.exists() else 0, reverse=True)
|
||||
|
||||
total = len(hits)
|
||||
shown = hits[:_GLOB_MAX_MATCHES]
|
||||
shown = hits[:100]
|
||||
|
||||
# Convert back to sandbox paths
|
||||
sandbox_paths = []
|
||||
output_bytes = 0
|
||||
truncated_by_bytes = False
|
||||
for h in shown:
|
||||
rel = os.path.relpath(str(h), host_path)
|
||||
sandbox_path = os.path.join(path, rel)
|
||||
entry_bytes = len(sandbox_path.encode('utf-8')) + (1 if sandbox_paths else 0)
|
||||
if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES:
|
||||
truncated_by_bytes = True
|
||||
break
|
||||
sandbox_paths.append(sandbox_path)
|
||||
output_bytes += entry_bytes
|
||||
|
||||
return {
|
||||
'ok': True,
|
||||
'matches': sandbox_paths,
|
||||
'preview': '\n'.join(sandbox_paths),
|
||||
'total': total,
|
||||
'truncated': total > len(sandbox_paths) or truncated_by_bytes,
|
||||
'truncated_by': 'bytes' if truncated_by_bytes else ('matches' if total > len(sandbox_paths) else None),
|
||||
}
|
||||
result_lines = sandbox_paths
|
||||
result = '\n'.join(result_lines)
|
||||
|
||||
if total > 100:
|
||||
result += f'\n... ({total} matches, showing first 100)'
|
||||
|
||||
return {'ok': True, 'matches': result_lines, 'total': total, 'truncated': total > 100}
|
||||
|
||||
async def _invoke_grep(self, parameters: dict, query: pipeline_query.Query) -> dict:
|
||||
pattern = parameters['pattern']
|
||||
@@ -911,46 +791,32 @@ else:
|
||||
files = self._grep_walk(base, include)
|
||||
|
||||
matches = []
|
||||
output_bytes = 0
|
||||
truncated_by = None
|
||||
for fp in files:
|
||||
try:
|
||||
handle = fp.open('r', encoding='utf-8', errors='ignore')
|
||||
text = fp.read_text(errors='ignore')
|
||||
except OSError:
|
||||
continue
|
||||
with handle:
|
||||
for lineno, line in enumerate(handle, 1):
|
||||
if regex.search(line):
|
||||
rel = os.path.relpath(str(fp), host_path)
|
||||
sandbox_path = os.path.join(path, rel)
|
||||
content, line_truncated = self._truncate_grep_line(line.rstrip())
|
||||
entry = {
|
||||
for lineno, line in enumerate(text.splitlines(), 1):
|
||||
if regex.search(line):
|
||||
rel = os.path.relpath(str(fp), host_path)
|
||||
sandbox_path = os.path.join(path, rel)
|
||||
matches.append(
|
||||
{
|
||||
'file': sandbox_path,
|
||||
'line': lineno,
|
||||
'content': content,
|
||||
'content': line.rstrip(),
|
||||
}
|
||||
entry_bytes = len(json.dumps(entry, ensure_ascii=False).encode('utf-8')) + 1
|
||||
if output_bytes + entry_bytes > _DEFAULT_TOOL_RESULT_MAX_BYTES:
|
||||
truncated_by = 'bytes'
|
||||
break
|
||||
if line_truncated and truncated_by is None:
|
||||
truncated_by = 'line'
|
||||
matches.append(entry)
|
||||
output_bytes += entry_bytes
|
||||
if len(matches) >= _GREP_MAX_MATCHES:
|
||||
truncated_by = truncated_by or 'matches'
|
||||
break
|
||||
if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES:
|
||||
break
|
||||
if truncated_by == 'bytes' or len(matches) >= _GREP_MAX_MATCHES:
|
||||
)
|
||||
if len(matches) >= 200:
|
||||
break
|
||||
if len(matches) >= 200:
|
||||
break
|
||||
|
||||
return {
|
||||
'ok': True,
|
||||
'matches': matches,
|
||||
'total': len(matches),
|
||||
'truncated': truncated_by is not None,
|
||||
'truncated_by': truncated_by,
|
||||
'truncated': len(matches) >= 200,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -962,207 +828,10 @@ else:
|
||||
continue
|
||||
if item.is_file():
|
||||
results.append(item)
|
||||
if len(results) >= _GREP_MAX_FILES:
|
||||
if len(results) >= 5000:
|
||||
break
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _resolve_skill_host_path(selected_skill: dict, relative: str) -> str | None:
|
||||
package_root = str(selected_skill.get('package_root', '') or '').strip()
|
||||
if not package_root:
|
||||
return None
|
||||
|
||||
host_root = os.path.realpath(package_root)
|
||||
host_path = os.path.realpath(os.path.join(host_root, relative))
|
||||
if not (host_path == host_root or host_path.startswith(host_root + os.sep)):
|
||||
raise ValueError('Path escapes the skill package boundary.')
|
||||
return host_path
|
||||
|
||||
def _normalize_exec_result(self, result: dict) -> dict:
|
||||
normalized = dict(result)
|
||||
stdout = str(normalized.get('stdout') or '')
|
||||
stderr = str(normalized.get('stderr') or '')
|
||||
stdout, stdout_capped = self._truncate_text_to_bytes_with_flag(stdout, _DEFAULT_TOOL_RESULT_MAX_BYTES)
|
||||
stderr, stderr_capped = self._truncate_text_to_bytes_with_flag(stderr, _DEFAULT_TOOL_RESULT_MAX_BYTES)
|
||||
normalized['stdout'] = stdout
|
||||
normalized['stderr'] = stderr
|
||||
normalized['stdout_truncated'] = bool(normalized.get('stdout_truncated') or stdout_capped)
|
||||
normalized['stderr_truncated'] = bool(normalized.get('stderr_truncated') or stderr_capped)
|
||||
|
||||
if stdout and stderr:
|
||||
preview_raw = f'stdout:\n{stdout}\n\nstderr:\n{stderr}'
|
||||
else:
|
||||
preview_raw = stdout or stderr
|
||||
preview, preview_capped = self._truncate_text_to_bytes_with_flag(preview_raw, _DEFAULT_TOOL_RESULT_MAX_BYTES)
|
||||
normalized['preview'] = preview
|
||||
normalized['truncated'] = bool(
|
||||
normalized['stdout_truncated'] or normalized['stderr_truncated'] or preview_capped
|
||||
)
|
||||
if preview_capped and not normalized.get('truncated_by'):
|
||||
normalized['truncated_by'] = 'bytes'
|
||||
return normalized
|
||||
|
||||
def _build_directory_result(self, entries: list[str]) -> dict:
|
||||
sorted_entries = sorted(str(entry) for entry in entries)
|
||||
content = '\n'.join(sorted_entries)
|
||||
preview = self._truncate_text_to_bytes(content, _DEFAULT_TOOL_RESULT_MAX_BYTES)
|
||||
truncated = preview != content
|
||||
return {
|
||||
'ok': True,
|
||||
'content': preview,
|
||||
'is_directory': True,
|
||||
'total': len(sorted_entries),
|
||||
'truncated': truncated,
|
||||
'truncated_by': 'bytes' if truncated else None,
|
||||
}
|
||||
|
||||
def _read_text_file_preview(self, host_path: str, parameters: dict) -> dict:
|
||||
offset = self._positive_int(parameters.get('offset'), default=1)
|
||||
max_lines = self._positive_int(
|
||||
parameters.get('limit'),
|
||||
default=_DEFAULT_READ_MAX_LINES,
|
||||
max_value=_MAX_READ_MAX_LINES,
|
||||
)
|
||||
max_bytes = self._positive_int(
|
||||
parameters.get('max_bytes'),
|
||||
default=_DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
)
|
||||
lines: list[str] = []
|
||||
output_bytes = 0
|
||||
end_line = offset - 1
|
||||
truncated = False
|
||||
truncated_by: str | None = None
|
||||
next_offset: int | None = None
|
||||
|
||||
with open(host_path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
for line_number, line in enumerate(f, 1):
|
||||
if line_number < offset:
|
||||
continue
|
||||
if len(lines) >= max_lines:
|
||||
truncated = True
|
||||
truncated_by = 'lines'
|
||||
next_offset = line_number
|
||||
break
|
||||
|
||||
line_bytes = len(line.encode('utf-8'))
|
||||
if output_bytes + line_bytes > max_bytes:
|
||||
truncated = True
|
||||
truncated_by = 'bytes'
|
||||
next_offset = line_number
|
||||
break
|
||||
|
||||
lines.append(line.rstrip('\n'))
|
||||
output_bytes += line_bytes
|
||||
end_line = line_number
|
||||
|
||||
if not lines and truncated_by == 'bytes':
|
||||
content = (
|
||||
f'[Line {next_offset or offset} exceeds the {self._format_size(max_bytes)} read limit. '
|
||||
'Use exec with a byte-range command for this line, or read a different offset.]'
|
||||
)
|
||||
else:
|
||||
content = '\n'.join(lines)
|
||||
|
||||
return {
|
||||
'ok': True,
|
||||
'content': content,
|
||||
'truncated': truncated,
|
||||
'truncated_by': truncated_by,
|
||||
'start_line': offset,
|
||||
'end_line': end_line,
|
||||
'next_offset': next_offset,
|
||||
'max_lines': max_lines,
|
||||
'max_bytes': max_bytes,
|
||||
}
|
||||
|
||||
def _build_read_result_from_text(self, content: str, parameters: dict) -> dict:
|
||||
offset = self._positive_int(parameters.get('offset'), default=1)
|
||||
max_lines = self._positive_int(
|
||||
parameters.get('limit'),
|
||||
default=_DEFAULT_READ_MAX_LINES,
|
||||
max_value=_MAX_READ_MAX_LINES,
|
||||
)
|
||||
max_bytes = self._positive_int(
|
||||
parameters.get('max_bytes'),
|
||||
default=_DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
max_value=_DEFAULT_TOOL_RESULT_MAX_BYTES,
|
||||
)
|
||||
all_lines = content.splitlines()
|
||||
start_index = offset - 1
|
||||
if start_index >= len(all_lines) and all_lines:
|
||||
return {'ok': False, 'error': f'Offset {offset} is beyond end of file ({len(all_lines)} lines total)'}
|
||||
output_lines: list[str] = []
|
||||
output_bytes = 0
|
||||
truncated = False
|
||||
truncated_by: str | None = None
|
||||
next_offset: int | None = None
|
||||
for index, line in enumerate(all_lines[start_index:], start_index + 1):
|
||||
if len(output_lines) >= max_lines:
|
||||
truncated = True
|
||||
truncated_by = 'lines'
|
||||
next_offset = index
|
||||
break
|
||||
line_bytes = len(line.encode('utf-8')) + (1 if output_lines else 0)
|
||||
if output_bytes + line_bytes > max_bytes:
|
||||
truncated = True
|
||||
truncated_by = 'bytes'
|
||||
next_offset = index
|
||||
break
|
||||
output_lines.append(line)
|
||||
output_bytes += line_bytes
|
||||
|
||||
end_line = offset + len(output_lines) - 1
|
||||
return {
|
||||
'ok': True,
|
||||
'content': '\n'.join(output_lines),
|
||||
'truncated': truncated,
|
||||
'truncated_by': truncated_by,
|
||||
'start_line': offset,
|
||||
'end_line': end_line,
|
||||
'next_offset': next_offset,
|
||||
'max_lines': max_lines,
|
||||
'max_bytes': max_bytes,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _positive_int(value, *, default: int, max_value: int | None = None) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
parsed = default
|
||||
if parsed <= 0:
|
||||
parsed = default
|
||||
if max_value is not None:
|
||||
parsed = min(parsed, max_value)
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _truncate_grep_line(line: str) -> tuple[str, bool]:
|
||||
if len(line) <= _GREP_MAX_LINE_CHARS:
|
||||
return line, False
|
||||
return f'{line[:_GREP_MAX_LINE_CHARS]}... [truncated]', True
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text_to_bytes(text: str, max_bytes: int) -> str:
|
||||
return NativeToolLoader._truncate_text_to_bytes_with_flag(text, max_bytes)[0]
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text_to_bytes_with_flag(text: str, max_bytes: int) -> tuple[str, bool]:
|
||||
data = text.encode('utf-8')
|
||||
if len(data) <= max_bytes:
|
||||
return text, False
|
||||
truncated = data[:max_bytes]
|
||||
while truncated and (truncated[-1] & 0xC0) == 0x80:
|
||||
truncated = truncated[:-1]
|
||||
return truncated.decode('utf-8', errors='ignore'), True
|
||||
|
||||
@staticmethod
|
||||
def _format_size(bytes_count: int) -> str:
|
||||
if bytes_count < 1024:
|
||||
return f'{bytes_count}B'
|
||||
return f'{bytes_count / 1024:.1f}KB'
|
||||
|
||||
def _summarize_parameters(self, parameters: dict) -> dict:
|
||||
summary = dict(parameters)
|
||||
cmd = str(summary.get('command', '')).strip()
|
||||
|
||||
@@ -72,45 +72,6 @@ def register_activated_skill(query: pipeline_query.Query, skill_data: dict) -> N
|
||||
activated[skill_name] = skill_data
|
||||
|
||||
|
||||
def normalize_skill_names(value: typing.Any) -> list[str]:
|
||||
"""Return a de-duplicated list of non-empty skill names."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
||||
names: list[str] = []
|
||||
for item in value:
|
||||
skill_name = str(item or '').strip()
|
||||
if skill_name and skill_name not in names:
|
||||
names.append(skill_name)
|
||||
return names
|
||||
|
||||
|
||||
def get_activated_skill_names(query: pipeline_query.Query) -> list[str]:
|
||||
"""Return activated skill names for callers that own persistence policy."""
|
||||
return normalize_skill_names(list(get_activated_skills(query).keys()))
|
||||
|
||||
|
||||
def restore_activated_skills(
|
||||
ap: app.Application,
|
||||
query: pipeline_query.Query,
|
||||
skill_names: typing.Any,
|
||||
) -> list[str]:
|
||||
"""Restore caller-provided activated skill names into Query variables.
|
||||
|
||||
Persistence and state scope ownership belong to higher-level flows. This
|
||||
helper only rebuilds current Query state from pipeline-visible skills, so
|
||||
removed or unbound skills stay unavailable to native exec/write/edit.
|
||||
"""
|
||||
restored: list[str] = []
|
||||
for skill_name in normalize_skill_names(skill_names):
|
||||
skill_data = get_visible_skill(ap, query, skill_name)
|
||||
if skill_data is None:
|
||||
continue
|
||||
register_activated_skill(query, skill_data)
|
||||
restored.append(skill_name)
|
||||
return restored
|
||||
|
||||
|
||||
def parse_skill_mount_path(sandbox_path: str) -> tuple[str | None, str]:
|
||||
normalized_path = str(sandbox_path or '/workspace').strip() or '/workspace'
|
||||
if normalized_path == SKILL_MOUNT_PREFIX:
|
||||
|
||||
@@ -6,7 +6,6 @@ import typing
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
from .. import loader
|
||||
from .availability import is_box_backend_available
|
||||
|
||||
# Align with Claude Code's Skill tool design:
|
||||
# - activate: Activate a skill via Tool Call, returns SKILL.md content
|
||||
@@ -46,7 +45,18 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
|
||||
async def _check_sandbox_available(self) -> bool:
|
||||
"""Check if the box backend is truly available (not just the runtime)."""
|
||||
return await is_box_backend_available(self.ap)
|
||||
box_service = getattr(self.ap, 'box_service', None)
|
||||
if box_service is None:
|
||||
return False
|
||||
if not getattr(box_service, 'available', False):
|
||||
return False
|
||||
# Check if backend is truly available via get_status
|
||||
try:
|
||||
status = await box_service.get_status()
|
||||
backend_info = status.get('backend', {})
|
||||
return backend_info.get('available', False)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_tools(self, bound_plugins: list[str] | None = None) -> list[resource_tool.LLMTool]:
|
||||
if not self._is_available():
|
||||
@@ -82,15 +92,16 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
if not skill_name:
|
||||
raise ValueError('skill_name is required')
|
||||
|
||||
from . import skill as skill_loader
|
||||
|
||||
skill_data = skill_loader.get_visible_skill(self.ap, query, skill_name)
|
||||
skill_mgr = self.ap.skill_mgr
|
||||
skill_data = skill_mgr.get_skill_by_name(skill_name)
|
||||
if skill_data is None:
|
||||
visible_skills = skill_loader.get_visible_skills(self.ap, query)
|
||||
visible_skills = getattr(skill_mgr, 'skills', {})
|
||||
available_names = ', '.join(sorted(visible_skills.keys())) or 'none'
|
||||
raise ValueError(f'Skill "{skill_name}" not found. Available skills: {available_names}')
|
||||
|
||||
# Register activated skill for sandbox mount path resolution
|
||||
from . import skill as skill_loader
|
||||
|
||||
skill_loader.register_activated_skill(query, skill_data)
|
||||
|
||||
# Return SKILL.md content as Tool Result (injects into context)
|
||||
@@ -116,7 +127,6 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
'activated': True,
|
||||
'skill_name': skill_name,
|
||||
'mount_path': mount_path,
|
||||
'activated_skill_names': skill_loader.get_activated_skill_names(query),
|
||||
'content': result_content,
|
||||
}
|
||||
|
||||
@@ -191,13 +201,13 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
return resource_tool.LLMTool(
|
||||
name=ACTIVATE_SKILL_TOOL_NAME,
|
||||
human_desc='Activate a skill',
|
||||
description='Activate a pipeline-visible skill by name and return its instructions as a tool result.',
|
||||
description=self._build_activate_tool_description(),
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'skill_name': {
|
||||
'type': 'string',
|
||||
'description': 'The skill name to activate.',
|
||||
'description': 'The skill name to activate (no arguments). E.g., "pdf" or "data-analysis"',
|
||||
},
|
||||
},
|
||||
'required': ['skill_name'],
|
||||
@@ -245,3 +255,50 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
},
|
||||
func=lambda parameters: parameters,
|
||||
)
|
||||
|
||||
def _build_activate_tool_description(self) -> str:
|
||||
"""Build tool description with embedded available_skills list."""
|
||||
skill_mgr = getattr(self.ap, 'skill_mgr', None)
|
||||
if skill_mgr is None:
|
||||
return 'Activate a skill. No skills are currently available.'
|
||||
|
||||
skills = getattr(skill_mgr, 'skills', {})
|
||||
if not skills:
|
||||
return 'Activate a skill. No skills are currently available.'
|
||||
|
||||
# Build <available_skills> section
|
||||
available_skills_lines = ['<available_skills>']
|
||||
for skill_name, skill_data in sorted(skills.items()):
|
||||
description = skill_data.get('description', '')
|
||||
available_skills_lines.append('<skill>')
|
||||
available_skills_lines.append(f'<name>{skill_name}</name>')
|
||||
available_skills_lines.append(f'<description>{description}</description>')
|
||||
available_skills_lines.append('</skill>')
|
||||
available_skills_lines.append('</available_skills>')
|
||||
|
||||
available_skills_block = '\n'.join(available_skills_lines)
|
||||
|
||||
return f"""Activate a skill within the main conversation.
|
||||
|
||||
<skills_instructions>
|
||||
When users ask you to perform tasks, check if any of the available skills
|
||||
below can help complete the task more effectively. Skills provide specialized
|
||||
capabilities and domain knowledge.
|
||||
|
||||
How to use skills:
|
||||
- Invoke skills using this tool with the skill name only (no arguments)
|
||||
- When you invoke a skill, you will see <command-message>
|
||||
The skill is activated
|
||||
</command-message>
|
||||
- The skill's instructions will be provided in the tool result
|
||||
- Examples:
|
||||
- skill_name: "pdf" - invoke the pdf skill
|
||||
- skill_name: "data-analysis" - invoke the data-analysis skill
|
||||
|
||||
Important:
|
||||
- Only use skills listed in <available_skills> below
|
||||
- Do not invoke a skill that is already running
|
||||
- To create a new skill: prepare it in /workspace, then use register_skill tool
|
||||
</skills_instructions>
|
||||
|
||||
{available_skills_block}"""
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
import uuid
|
||||
import zipfile
|
||||
import io
|
||||
import datetime
|
||||
from typing import Any
|
||||
from langbot.pkg.core import app
|
||||
import sqlalchemy
|
||||
@@ -25,6 +26,10 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
super().__init__(ap)
|
||||
self.knowledge_base_entity = knowledge_base_entity
|
||||
|
||||
@staticmethod
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@@ -334,6 +339,24 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
# are passed directly to vector_search by some plugins (e.g. LangRAG)
|
||||
# and would cause empty results when the metadata field doesn't exist.
|
||||
filters = settings.pop('filters', {})
|
||||
trace_context = settings.pop('_trace_context', None)
|
||||
host_span_started_at = self._utc_now()
|
||||
host_span_id = None
|
||||
if trace_context and trace_context.get('trace_id'):
|
||||
host_parent_span_id = trace_context.get('parent_span_id')
|
||||
host_span_id = trace_context.get('rag_span_id') or f'span-{uuid.uuid4().hex[:16]}'
|
||||
trace_context = {
|
||||
'trace_id': trace_context.get('trace_id'),
|
||||
'parent_span_id': host_span_id,
|
||||
'host_parent_span_id': host_parent_span_id,
|
||||
'message_id': trace_context.get('message_id'),
|
||||
'query_id': trace_context.get('query_id'),
|
||||
'session_id': trace_context.get('session_id'),
|
||||
'bot_id': trace_context.get('bot_id'),
|
||||
'pipeline_id': trace_context.get('pipeline_id'),
|
||||
'knowledge_base_id': kb.uuid,
|
||||
'attributes': trace_context.get('attributes') or {},
|
||||
}
|
||||
|
||||
retrieval_context = {
|
||||
'query': query,
|
||||
@@ -343,13 +366,104 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface):
|
||||
'creation_settings': kb.creation_settings or {},
|
||||
'filters': filters,
|
||||
}
|
||||
if trace_context:
|
||||
retrieval_context['trace_context'] = trace_context
|
||||
|
||||
result = await self.ap.plugin_connector.call_rag_retrieve(
|
||||
plugin_id,
|
||||
retrieval_context,
|
||||
)
|
||||
try:
|
||||
result = await self.ap.plugin_connector.call_rag_retrieve(
|
||||
plugin_id,
|
||||
retrieval_context,
|
||||
)
|
||||
except Exception as e:
|
||||
if trace_context:
|
||||
await self._record_rag_trace_result(
|
||||
trace_context=trace_context,
|
||||
host_span_id=host_span_id,
|
||||
started_at=host_span_started_at,
|
||||
plugin_id=plugin_id,
|
||||
result={
|
||||
'results': [],
|
||||
'metadata': {
|
||||
'status': 'error',
|
||||
'error_message': str(e),
|
||||
},
|
||||
},
|
||||
)
|
||||
raise
|
||||
if trace_context:
|
||||
await self._record_rag_trace_result(
|
||||
trace_context=trace_context,
|
||||
host_span_id=host_span_id,
|
||||
started_at=host_span_started_at,
|
||||
plugin_id=plugin_id,
|
||||
result=result,
|
||||
)
|
||||
return result
|
||||
|
||||
async def _record_rag_trace_result(
|
||||
self,
|
||||
trace_context: dict[str, Any],
|
||||
host_span_id: str | None,
|
||||
started_at: datetime.datetime,
|
||||
plugin_id: str,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Persist host RAG span and plugin-provided child spans."""
|
||||
trace_id = trace_context.get('trace_id')
|
||||
if not trace_id:
|
||||
return
|
||||
|
||||
metadata = result.get('metadata') if isinstance(result, dict) else {}
|
||||
metadata = metadata if isinstance(metadata, dict) else {}
|
||||
plugin_spans = metadata.get('trace_spans') if isinstance(metadata.get('trace_spans'), list) else []
|
||||
parent_span_id = trace_context.get('parent_span_id')
|
||||
host_parent_span_id = trace_context.get('host_parent_span_id')
|
||||
|
||||
try:
|
||||
await self.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
span_id=host_span_id,
|
||||
parent_span_id=host_parent_span_id,
|
||||
name=f'Knowledge retrieval {self.knowledge_base_entity.name}',
|
||||
kind='rag.retrieval',
|
||||
status=metadata.get('status', 'success'),
|
||||
started_at=started_at,
|
||||
duration=metadata.get('duration_ms'),
|
||||
message_id=trace_context.get('message_id'),
|
||||
session_id=trace_context.get('session_id'),
|
||||
bot_id=trace_context.get('bot_id'),
|
||||
pipeline_id=trace_context.get('pipeline_id'),
|
||||
attributes={
|
||||
'knowledge_base_id': self.knowledge_base_entity.uuid,
|
||||
'knowledge_base_name': self.knowledge_base_entity.name,
|
||||
'plugin_id': plugin_id,
|
||||
'returned_count': len(result.get('results', []) if isinstance(result, dict) else []),
|
||||
'total_found': result.get('total_found') if isinstance(result, dict) else None,
|
||||
},
|
||||
error_message=metadata.get('error_message'),
|
||||
)
|
||||
for span in plugin_spans:
|
||||
if not isinstance(span, dict):
|
||||
continue
|
||||
await self.ap.monitoring_service.record_span(
|
||||
trace_id=trace_id,
|
||||
span_id=span.get('span_id'),
|
||||
parent_span_id=span.get('parent_span_id') or host_span_id or parent_span_id,
|
||||
name=span.get('name') or 'RAG plugin stage',
|
||||
kind=span.get('kind') or 'rag.stage',
|
||||
status=span.get('status') or 'success',
|
||||
started_at=started_at,
|
||||
duration=span.get('duration_ms'),
|
||||
message_id=trace_context.get('message_id'),
|
||||
session_id=trace_context.get('session_id'),
|
||||
bot_id=trace_context.get('bot_id'),
|
||||
pipeline_id=trace_context.get('pipeline_id'),
|
||||
attributes=span.get('attributes') if isinstance(span.get('attributes'), dict) else {},
|
||||
error_message=span.get('error_message'),
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Failed to record RAG trace spans: {e}')
|
||||
|
||||
async def _delete_document(self, document_id: str) -> bool:
|
||||
"""Call plugin to delete document."""
|
||||
kb = self.knowledge_base_entity
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# LangBot Test Suite
|
||||
|
||||
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
|
||||
This directory contains the LangBot backend test suite, including unit tests,
|
||||
integration tests, startup E2E tests, and container-backed Box runtime tests.
|
||||
|
||||
## Quality Gate Layers
|
||||
|
||||
@@ -10,10 +11,15 @@ LangBot uses a layered quality gate system for developers and CI:
|
||||
|-------|---------|--------------|-------------|
|
||||
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
|
||||
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
|
||||
| **Backend E2E** | `uv run --python 3.12 pytest tests/e2e -q --tb=short` | Starts a real LangBot process with minimal config | Before release, CI |
|
||||
| **Box Integration** | `uv run --python 3.12 pytest tests/integration_tests -q --tb=short` | Real Box sandbox/runtime integration | Before Box/runtime changes, CI |
|
||||
| **Frontend E2E** | `cd web && pnpm test:e2e` | Playwright smoke tests with mocked backend and Space APIs | Before web changes, CI |
|
||||
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
|
||||
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
|
||||
|
||||
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
|
||||
**Note**: PostgreSQL migration tests and slow tests are NOT in local default
|
||||
gates. They run in separate CI workflows. Frontend Playwright tests live under
|
||||
`web/tests/e2e` and are documented in `web/README.md`.
|
||||
|
||||
### Developer Workflow
|
||||
|
||||
@@ -28,6 +34,9 @@ make test-all-local
|
||||
bash scripts/test-quick.sh # ~2 min
|
||||
bash scripts/test-integration-fast.sh # ~3 min
|
||||
bash scripts/test-coverage.sh # ~8 min
|
||||
uv run --python 3.12 pytest tests/e2e -q --tb=short
|
||||
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
|
||||
cd web && pnpm test:e2e
|
||||
```
|
||||
|
||||
### Coverage Baseline
|
||||
@@ -70,6 +79,12 @@ tests/
|
||||
│ └── persistence/ # Database/persistence tests
|
||||
│ ├── __init__.py
|
||||
│ └── test_migrations.py # Alembic migration tests
|
||||
├── e2e/ # Real LangBot startup E2E tests
|
||||
│ ├── conftest.py
|
||||
│ ├── test_startup.py
|
||||
│ └── utils/
|
||||
├── integration_tests/ # Container-backed integration tests
|
||||
│ └── box/ # Box runtime and MCP process tests
|
||||
├── smoke/ # Smoke tests (quick validation)
|
||||
│ └── test_fake_message_flow.py
|
||||
├── unit_tests/ # Unit tests
|
||||
@@ -303,6 +318,44 @@ These tests:
|
||||
- Test prevent_default, exception handling, and full message flow
|
||||
- Do not require real LLM provider keys
|
||||
|
||||
### Running backend E2E startup tests
|
||||
|
||||
Backend E2E tests start a real LangBot process with a generated minimal
|
||||
`data/config.yaml`, SQLite database, local storage, and embedded Chroma path.
|
||||
They do not require provider keys or external services.
|
||||
|
||||
```bash
|
||||
uv run --python 3.12 pytest tests/e2e -q --tb=short
|
||||
```
|
||||
|
||||
These tests verify startup orchestration, migrations, API route registration,
|
||||
and the minimal no-LLM startup path. The E2E process manager disables ambient
|
||||
proxy variables for subprocess startup and uses direct localhost HTTP clients,
|
||||
so local proxy settings should not affect the health checks.
|
||||
|
||||
### Running Box integration tests
|
||||
|
||||
Box integration tests exercise the real sandbox runtime path, including command
|
||||
execution, session persistence, managed process WebSocket attachment, and
|
||||
cleanup behavior.
|
||||
|
||||
```bash
|
||||
uv run --python 3.12 pytest tests/integration_tests -q --tb=short
|
||||
```
|
||||
|
||||
These tests require a working Docker or Podman runtime. In CI, the dedicated
|
||||
Box integration job checks Docker availability before running the tests.
|
||||
|
||||
### Running frontend E2E tests
|
||||
|
||||
Frontend E2E tests live in `web/tests/e2e` and use Playwright. They start Vite
|
||||
and mock the LangBot backend and Space APIs, so no backend process is required.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm test:e2e
|
||||
```
|
||||
|
||||
### Known Issues
|
||||
|
||||
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
|
||||
@@ -320,6 +373,9 @@ Tests are automatically run on:
|
||||
- Push to master/develop branches
|
||||
|
||||
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
|
||||
Startup E2E and Box integration tests run as separate Python 3.12 jobs because
|
||||
they exercise process/container behavior instead of pure Python compatibility.
|
||||
Frontend Playwright smoke tests run in `.github/workflows/frontend-tests.yml`.
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
@@ -406,4 +462,4 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
|
||||
- [ ] Add E2E tests
|
||||
- [ ] Add performance benchmarks
|
||||
- [ ] Add mutation testing for better coverage quality
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
- [ ] Add property-based testing with Hypothesis
|
||||
|
||||
@@ -92,11 +92,11 @@ def e2e_client(e2e_port, langbot_process):
|
||||
|
||||
base_url = f'http://127.0.0.1:{e2e_port}'
|
||||
|
||||
with httpx.Client(base_url=base_url, timeout=10.0) as client:
|
||||
with httpx.Client(base_url=base_url, timeout=10.0, trust_env=False) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def e2e_db_path(e2e_tmpdir):
|
||||
"""Path to SQLite database file."""
|
||||
return e2e_tmpdir / 'data' / 'langbot.db'
|
||||
return e2e_tmpdir / 'data' / 'langbot.db'
|
||||
|
||||
@@ -38,12 +38,13 @@ class TestStartupFlow:
|
||||
# System info should contain version info
|
||||
assert 'version' in data['data'] or 'edition' in data['data']
|
||||
|
||||
def test_database_initialized(self, e2e_db_path):
|
||||
def test_database_initialized(self, langbot_process, e2e_db_path):
|
||||
"""Verify SQLite database was created and initialized."""
|
||||
assert e2e_db_path.exists()
|
||||
|
||||
# Database should have some tables after migration
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(e2e_db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -74,10 +75,13 @@ class TestStartupFlow:
|
||||
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
|
||||
"""Test auth endpoint."""
|
||||
# First startup may allow initial setup
|
||||
response = e2e_client.post('/api/v1/user/auth', json={
|
||||
'username': 'admin',
|
||||
'password': 'admin',
|
||||
})
|
||||
response = e2e_client.post(
|
||||
'/api/v1/user/auth',
|
||||
json={
|
||||
'user': 'admin',
|
||||
'password': 'admin',
|
||||
},
|
||||
)
|
||||
|
||||
# Response could be:
|
||||
# - 200 if auth succeeds
|
||||
@@ -94,9 +98,10 @@ class TestStartupStages:
|
||||
# If API responds on e2e_port, config was loaded
|
||||
assert e2e_client.get('/api/v1/system/info').status_code == 200
|
||||
|
||||
def test_migrations_applied(self, e2e_db_path):
|
||||
def test_migrations_applied(self, langbot_process, e2e_db_path):
|
||||
"""Verify database migrations were applied."""
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(e2e_db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
@@ -176,4 +176,4 @@ def create_test_directories(tmpdir: Path) -> dict[str, Path]:
|
||||
for path in directories.values():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return directories
|
||||
return directories
|
||||
|
||||
@@ -44,6 +44,17 @@ class LangBotProcess:
|
||||
# Prepare environment
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = str(self.project_root / 'src')
|
||||
for proxy_key in (
|
||||
'HTTP_PROXY',
|
||||
'HTTPS_PROXY',
|
||||
'ALL_PROXY',
|
||||
'http_proxy',
|
||||
'https_proxy',
|
||||
'all_proxy',
|
||||
):
|
||||
env.pop(proxy_key, None)
|
||||
env['NO_PROXY'] = '127.0.0.1,localhost'
|
||||
env['no_proxy'] = '127.0.0.1,localhost'
|
||||
|
||||
# Set API port via environment variable
|
||||
env['API__PORT'] = str(self.port)
|
||||
@@ -79,9 +90,11 @@ precision = 2
|
||||
f.write(coveragerc_content)
|
||||
|
||||
cmd = [
|
||||
'coverage', 'run',
|
||||
'coverage',
|
||||
'run',
|
||||
'--rcfile=' + str(coveragerc_path),
|
||||
'-m', 'langbot',
|
||||
'-m',
|
||||
'langbot',
|
||||
]
|
||||
else:
|
||||
cmd = ['uv', 'run', 'python', '-m', 'langbot']
|
||||
@@ -113,6 +126,8 @@ precision = 2
|
||||
r = httpx.get(
|
||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||
timeout=2.0,
|
||||
follow_redirects=False,
|
||||
trust_env=False,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
logger.info(f'LangBot started successfully on port {self.port}')
|
||||
@@ -185,6 +200,8 @@ precision = 2
|
||||
r = httpx.get(
|
||||
f'http://127.0.0.1:{self.port}/api/v1/system/info',
|
||||
timeout=5.0,
|
||||
follow_redirects=False,
|
||||
trust_env=False,
|
||||
)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
@@ -201,4 +218,4 @@ def find_project_root() -> Path:
|
||||
return parent
|
||||
|
||||
# Fallback to LangBot-test-build directory
|
||||
return Path('/home/glwuy/langbot-app/LangBot-test-build')
|
||||
return Path('/home/glwuy/langbot-app/LangBot-test-build')
|
||||
|
||||
@@ -58,45 +58,45 @@ from tests.factories.platform import (
|
||||
|
||||
__all__ = [
|
||||
# App
|
||||
"FakeApp",
|
||||
"fake_app",
|
||||
'FakeApp',
|
||||
'fake_app',
|
||||
# Message chains
|
||||
"text_chain",
|
||||
"group_text_chain",
|
||||
"mention_chain",
|
||||
"image_chain",
|
||||
'text_chain',
|
||||
'group_text_chain',
|
||||
'mention_chain',
|
||||
'image_chain',
|
||||
# Message events
|
||||
"friend_message_event",
|
||||
"group_message_event",
|
||||
'friend_message_event',
|
||||
'group_message_event',
|
||||
# Mock adapters
|
||||
"mock_adapter",
|
||||
'mock_adapter',
|
||||
# Queries
|
||||
"text_query",
|
||||
"group_text_query",
|
||||
"private_text_query",
|
||||
"command_query",
|
||||
"mention_query",
|
||||
"empty_query",
|
||||
"image_query",
|
||||
"file_query",
|
||||
"unsupported_query",
|
||||
"voice_query",
|
||||
"at_all_query",
|
||||
"query_with_session",
|
||||
"query_with_config",
|
||||
'text_query',
|
||||
'group_text_query',
|
||||
'private_text_query',
|
||||
'command_query',
|
||||
'mention_query',
|
||||
'empty_query',
|
||||
'image_query',
|
||||
'file_query',
|
||||
'unsupported_query',
|
||||
'voice_query',
|
||||
'at_all_query',
|
||||
'query_with_session',
|
||||
'query_with_config',
|
||||
# Provider
|
||||
"FakeProvider",
|
||||
"fake_provider",
|
||||
"fake_provider_pong",
|
||||
"fake_provider_timeout",
|
||||
"fake_provider_auth_error",
|
||||
"fake_provider_rate_limit",
|
||||
"fake_provider_malformed",
|
||||
"fake_model",
|
||||
'FakeProvider',
|
||||
'fake_provider',
|
||||
'fake_provider_pong',
|
||||
'fake_provider_timeout',
|
||||
'fake_provider_auth_error',
|
||||
'fake_provider_rate_limit',
|
||||
'fake_provider_malformed',
|
||||
'fake_model',
|
||||
# Platform
|
||||
"FakePlatform",
|
||||
"fake_platform",
|
||||
"fake_platform_with_streaming",
|
||||
"fake_platform_with_failure",
|
||||
"mock_platform_adapter",
|
||||
]
|
||||
'FakePlatform',
|
||||
'fake_platform',
|
||||
'fake_platform_with_streaming',
|
||||
'fake_platform_with_failure',
|
||||
'mock_platform_adapter',
|
||||
]
|
||||
|
||||
@@ -30,32 +30,36 @@ def _next_query_id() -> int:
|
||||
# ============== Message Chain Factories ==============
|
||||
|
||||
|
||||
def text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||
def text_chain(text: str = 'hello') -> platform_message.MessageChain:
|
||||
"""Create a simple text message chain."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text=text),
|
||||
])
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text=text),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
|
||||
def group_text_chain(text: str = 'hello') -> platform_message.MessageChain:
|
||||
"""Create a group text message chain (same as text_chain, context provided by event)."""
|
||||
return text_chain(text)
|
||||
|
||||
|
||||
def mention_chain(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
target: typing.Union[int, str] = 12345,
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a message chain with @mention."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.At(target=target),
|
||||
platform_message.Plain(text=f" {text}"),
|
||||
])
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.At(target=target),
|
||||
platform_message.Plain(text=f' {text}'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def image_chain(
|
||||
text: str = "",
|
||||
url: str = "https://example.com/image.png",
|
||||
text: str = '',
|
||||
url: str = 'https://example.com/image.png',
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a message chain with an image."""
|
||||
components = []
|
||||
@@ -66,13 +70,15 @@ def image_chain(
|
||||
|
||||
|
||||
def command_chain(
|
||||
command: str = "help",
|
||||
prefix: str = "/",
|
||||
command: str = 'help',
|
||||
prefix: str = '/',
|
||||
) -> platform_message.MessageChain:
|
||||
"""Create a command message chain."""
|
||||
return platform_message.MessageChain([
|
||||
platform_message.Plain(text=f"{prefix}{command}"),
|
||||
])
|
||||
return platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text=f'{prefix}{command}'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ============== Message Event Factories ==============
|
||||
@@ -81,7 +87,7 @@ def command_chain(
|
||||
def friend_message_event(
|
||||
message_chain: platform_message.MessageChain,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
nickname: str = "TestUser",
|
||||
nickname: str = 'TestUser',
|
||||
) -> platform_events.FriendMessage:
|
||||
"""Create a friend (private) message event."""
|
||||
sender = platform_entities.Friend(
|
||||
@@ -90,7 +96,7 @@ def friend_message_event(
|
||||
remark=None,
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
type='FriendMessage',
|
||||
sender=sender,
|
||||
message_chain=message_chain,
|
||||
time=1609459200,
|
||||
@@ -100,9 +106,9 @@ def friend_message_event(
|
||||
def group_message_event(
|
||||
message_chain: platform_message.MessageChain,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
sender_name: str = "TestUser",
|
||||
sender_name: str = 'TestUser',
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
group_name: str = "TestGroup",
|
||||
group_name: str = 'TestGroup',
|
||||
) -> platform_events.GroupMessage:
|
||||
"""Create a group message event."""
|
||||
group = platform_entities.Group(
|
||||
@@ -117,7 +123,7 @@ def group_message_event(
|
||||
group=group,
|
||||
)
|
||||
return platform_events.GroupMessage(
|
||||
type="GroupMessage",
|
||||
type='GroupMessage',
|
||||
sender=sender,
|
||||
message_chain=message_chain,
|
||||
time=1609459200,
|
||||
@@ -152,36 +158,36 @@ def _base_query(
|
||||
query_id = _next_query_id()
|
||||
|
||||
base_data = {
|
||||
"query_id": query_id,
|
||||
"launcher_type": launcher_type,
|
||||
"launcher_id": launcher_id,
|
||||
"sender_id": sender_id,
|
||||
"message_chain": message_chain,
|
||||
"message_event": message_event,
|
||||
"adapter": adapter,
|
||||
"pipeline_uuid": "test-pipeline-uuid",
|
||||
"bot_uuid": "test-bot-uuid",
|
||||
"pipeline_config": {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||
"prompt": "test-prompt",
|
||||
'query_id': query_id,
|
||||
'launcher_type': launcher_type,
|
||||
'launcher_id': launcher_id,
|
||||
'sender_id': sender_id,
|
||||
'message_chain': message_chain,
|
||||
'message_event': message_event,
|
||||
'adapter': adapter,
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
'bot_uuid': 'test-bot-uuid',
|
||||
'pipeline_config': {
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {
|
||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
||||
'prompt': 'test-prompt',
|
||||
},
|
||||
},
|
||||
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||
"trigger": {"misc": {"combine-quote-message": False}},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
},
|
||||
"session": None,
|
||||
"prompt": None,
|
||||
"messages": [],
|
||||
"user_message": None,
|
||||
"use_funcs": [],
|
||||
"use_llm_model_uuid": None,
|
||||
"variables": {},
|
||||
"resp_messages": [],
|
||||
"resp_message_chain": None,
|
||||
"current_stage_name": None,
|
||||
'session': None,
|
||||
'prompt': None,
|
||||
'messages': [],
|
||||
'user_message': None,
|
||||
'use_funcs': [],
|
||||
'use_llm_model_uuid': None,
|
||||
'variables': {},
|
||||
'resp_messages': [],
|
||||
'resp_message_chain': None,
|
||||
'current_stage_name': None,
|
||||
}
|
||||
|
||||
# Apply overrides
|
||||
@@ -192,7 +198,7 @@ def _base_query(
|
||||
|
||||
|
||||
def text_query(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -212,7 +218,7 @@ def text_query(
|
||||
|
||||
|
||||
def private_text_query(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -221,7 +227,7 @@ def private_text_query(
|
||||
|
||||
|
||||
def group_text_query(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
**overrides,
|
||||
@@ -242,8 +248,8 @@ def group_text_query(
|
||||
|
||||
|
||||
def command_query(
|
||||
command: str = "help",
|
||||
prefix: str = "/",
|
||||
command: str = 'help',
|
||||
prefix: str = '/',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -263,7 +269,7 @@ def command_query(
|
||||
|
||||
|
||||
def mention_query(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
target: typing.Union[int, str] = 12345,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
@@ -301,8 +307,8 @@ def empty_query(**overrides) -> pipeline_query.Query:
|
||||
|
||||
|
||||
def image_query(
|
||||
text: str = "",
|
||||
url: str = "https://example.com/image.png",
|
||||
text: str = '',
|
||||
url: str = 'https://example.com/image.png',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -322,9 +328,9 @@ def image_query(
|
||||
|
||||
|
||||
def file_query(
|
||||
url: str = "https://example.com/document.pdf",
|
||||
name: str = "document.pdf",
|
||||
text: str = "",
|
||||
url: str = 'https://example.com/document.pdf',
|
||||
name: str = 'document.pdf',
|
||||
text: str = '',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -348,8 +354,8 @@ def file_query(
|
||||
|
||||
|
||||
def unsupported_query(
|
||||
unsupported_type: str = "CustomComponent",
|
||||
text: str = "",
|
||||
unsupported_type: str = 'CustomComponent',
|
||||
text: str = '',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -358,7 +364,7 @@ def unsupported_query(
|
||||
if text:
|
||||
components.append(platform_message.Plain(text=text))
|
||||
# Use Unknown component for unsupported types
|
||||
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
|
||||
components.append(platform_message.Unknown(text=f'Unsupported: {unsupported_type}'))
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = friend_message_event(chain, sender_id)
|
||||
adapter = mock_adapter()
|
||||
@@ -374,7 +380,7 @@ def unsupported_query(
|
||||
|
||||
|
||||
def query_with_session(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
session: provider_session.Session = None,
|
||||
**overrides,
|
||||
@@ -389,7 +395,7 @@ def query_with_session(
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=sender_id,
|
||||
sender_id=sender_id,
|
||||
use_prompt_name="default",
|
||||
use_prompt_name='default',
|
||||
using_conversation=None,
|
||||
conversations=[],
|
||||
)
|
||||
@@ -398,7 +404,7 @@ def query_with_session(
|
||||
|
||||
|
||||
def query_with_config(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
pipeline_config: dict = None,
|
||||
**overrides,
|
||||
@@ -410,22 +416,22 @@ def query_with_config(
|
||||
"""
|
||||
if pipeline_config is None:
|
||||
pipeline_config = {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": {"primary": "test-model-uuid", "fallbacks": []},
|
||||
"prompt": "test-prompt",
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {
|
||||
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
|
||||
'prompt': 'test-prompt',
|
||||
},
|
||||
},
|
||||
"output": {"misc": {"at-sender": False, "quote-origin": False}},
|
||||
"trigger": {"misc": {"combine-quote-message": False}},
|
||||
'output': {'misc': {'at-sender': False, 'quote-origin': False}},
|
||||
'trigger': {'misc': {'combine-quote-message': False}},
|
||||
}
|
||||
|
||||
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
|
||||
|
||||
|
||||
def voice_query(
|
||||
url: str = "https://example.com/audio.mp3",
|
||||
url: str = 'https://example.com/audio.mp3',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
**overrides,
|
||||
) -> pipeline_query.Query:
|
||||
@@ -448,7 +454,7 @@ def voice_query(
|
||||
|
||||
|
||||
def at_all_query(
|
||||
text: str = "hello",
|
||||
text: str = 'hello',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
**overrides,
|
||||
@@ -456,7 +462,7 @@ def at_all_query(
|
||||
"""Create a group query with @All mention."""
|
||||
components = [
|
||||
platform_message.AtAll(),
|
||||
platform_message.Plain(text=f" {text}"),
|
||||
platform_message.Plain(text=f' {text}'),
|
||||
]
|
||||
chain = platform_message.MessageChain(components)
|
||||
event = group_message_event(chain, sender_id, group_id=group_id)
|
||||
@@ -469,4 +475,4 @@ def at_all_query(
|
||||
sender_id=sender_id,
|
||||
adapter=adapter,
|
||||
**overrides,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ class FakePlatform:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bot_account_id: str = "test-bot",
|
||||
bot_account_id: str = 'test-bot',
|
||||
stream_output_supported: bool = False,
|
||||
raise_error: Exception = None,
|
||||
):
|
||||
@@ -48,16 +48,16 @@ class FakePlatform:
|
||||
# Registered listeners
|
||||
self._listeners: dict = {}
|
||||
|
||||
def raises(self, error: Exception) -> "FakePlatform":
|
||||
def raises(self, error: Exception) -> 'FakePlatform':
|
||||
"""Configure platform to raise an error on send."""
|
||||
self._raise_error = error
|
||||
return self
|
||||
|
||||
def send_failure(self) -> "FakePlatform":
|
||||
def send_failure(self) -> 'FakePlatform':
|
||||
"""Configure platform to simulate send failure."""
|
||||
return self.raises(Exception("Platform send failure"))
|
||||
return self.raises(Exception('Platform send failure'))
|
||||
|
||||
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
|
||||
def supports_streaming(self, supported: bool = True) -> 'FakePlatform':
|
||||
"""Configure whether streaming output is supported."""
|
||||
self._stream_output_supported = supported
|
||||
return self
|
||||
@@ -89,7 +89,7 @@ class FakePlatform:
|
||||
self,
|
||||
text: str,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
nickname: str = "TestUser",
|
||||
nickname: str = 'TestUser',
|
||||
) -> platform_events.FriendMessage:
|
||||
"""Create an inbound friend (private) message event."""
|
||||
sender = platform_entities.Friend(
|
||||
@@ -97,11 +97,13 @@ class FakePlatform:
|
||||
nickname=nickname,
|
||||
remark=None,
|
||||
)
|
||||
chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text=text),
|
||||
])
|
||||
chain = platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text=text),
|
||||
]
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
type='FriendMessage',
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
@@ -111,9 +113,9 @@ class FakePlatform:
|
||||
self,
|
||||
text: str,
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
sender_name: str = "TestUser",
|
||||
sender_name: str = 'TestUser',
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
group_name: str = "TestGroup",
|
||||
group_name: str = 'TestGroup',
|
||||
mention_bot: bool = False,
|
||||
) -> platform_events.GroupMessage:
|
||||
"""Create an inbound group message event.
|
||||
@@ -142,12 +144,12 @@ class FakePlatform:
|
||||
components = []
|
||||
if mention_bot:
|
||||
components.append(platform_message.At(target=self.bot_account_id))
|
||||
components.append(platform_message.Plain(text=" "))
|
||||
components.append(platform_message.Plain(text=' '))
|
||||
components.append(platform_message.Plain(text=text))
|
||||
|
||||
chain = platform_message.MessageChain(components)
|
||||
return platform_events.GroupMessage(
|
||||
type="GroupMessage",
|
||||
type='GroupMessage',
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
@@ -155,8 +157,8 @@ class FakePlatform:
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
url: str = "https://example.com/image.png",
|
||||
text: str = "",
|
||||
url: str = 'https://example.com/image.png',
|
||||
text: str = '',
|
||||
sender_id: typing.Union[int, str] = 12345,
|
||||
is_group: bool = False,
|
||||
group_id: typing.Union[int, str] = 99999,
|
||||
@@ -169,12 +171,12 @@ class FakePlatform:
|
||||
chain = platform_message.MessageChain(components)
|
||||
|
||||
if is_group:
|
||||
return self.create_group_message("", sender_id, group_id=group_id)
|
||||
return self.create_group_message('', sender_id, group_id=group_id)
|
||||
# Replace chain
|
||||
else:
|
||||
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
|
||||
sender = platform_entities.Friend(id=sender_id, nickname='TestUser', remark=None)
|
||||
return platform_events.FriendMessage(
|
||||
type="FriendMessage",
|
||||
type='FriendMessage',
|
||||
sender=sender,
|
||||
message_chain=chain,
|
||||
time=1609459200,
|
||||
@@ -192,12 +194,14 @@ class FakePlatform:
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_messages.append({
|
||||
"type": "send",
|
||||
"target_type": target_type,
|
||||
"target_id": target_id,
|
||||
"message": message,
|
||||
})
|
||||
self._outbound_messages.append(
|
||||
{
|
||||
'type': 'send',
|
||||
'target_type': target_type,
|
||||
'target_id': target_id,
|
||||
'message': message,
|
||||
}
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -209,13 +213,15 @@ class FakePlatform:
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_messages.append({
|
||||
"type": "reply",
|
||||
"source_type": message_source.type,
|
||||
"source": message_source,
|
||||
"message": message,
|
||||
"quote_origin": quote_origin,
|
||||
})
|
||||
self._outbound_messages.append(
|
||||
{
|
||||
'type': 'reply',
|
||||
'source_type': message_source.type,
|
||||
'source': message_source,
|
||||
'message': message,
|
||||
'quote_origin': quote_origin,
|
||||
}
|
||||
)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
@@ -229,15 +235,17 @@ class FakePlatform:
|
||||
if self._raise_error:
|
||||
raise self._raise_error
|
||||
|
||||
self._outbound_chunks.append({
|
||||
"type": "reply_chunk",
|
||||
"source_type": message_source.type,
|
||||
"source": message_source,
|
||||
"bot_message": bot_message,
|
||||
"message": message,
|
||||
"quote_origin": quote_origin,
|
||||
"is_final": is_final,
|
||||
})
|
||||
self._outbound_chunks.append(
|
||||
{
|
||||
'type': 'reply_chunk',
|
||||
'source_type': message_source.type,
|
||||
'source': message_source,
|
||||
'bot_message': bot_message,
|
||||
'message': message,
|
||||
'quote_origin': quote_origin,
|
||||
'is_final': is_final,
|
||||
}
|
||||
)
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""Return whether streaming output is supported."""
|
||||
@@ -295,7 +303,7 @@ class FakePlatform:
|
||||
|
||||
|
||||
def fake_platform(
|
||||
bot_account_id: str = "test-bot",
|
||||
bot_account_id: str = 'test-bot',
|
||||
stream_output_supported: bool = False,
|
||||
) -> FakePlatform:
|
||||
"""Create a FakePlatform instance."""
|
||||
@@ -328,9 +336,7 @@ def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
|
||||
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
|
||||
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
|
||||
adapter.send_message = AsyncMock(side_effect=platform.send_message)
|
||||
adapter.is_stream_output_supported = AsyncMock(
|
||||
return_value=platform._stream_output_supported
|
||||
)
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=platform._stream_output_supported)
|
||||
adapter._fake_platform = platform # Store for assertions
|
||||
|
||||
return adapter
|
||||
return adapter
|
||||
|
||||
@@ -27,51 +27,51 @@ class FakeProvider:
|
||||
Does not require API keys.
|
||||
"""
|
||||
|
||||
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
|
||||
PONG_RESPONSE = 'LANGBOT_FAKE_PONG'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
default_response: str = "fake response",
|
||||
default_response: str = 'fake response',
|
||||
streaming_chunks: list[str] = None,
|
||||
raise_error: Exception = None,
|
||||
captured_requests: list = None,
|
||||
):
|
||||
self._default_response = default_response
|
||||
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
|
||||
self._streaming_chunks = streaming_chunks or ['fake ', 'response']
|
||||
self._raise_error = raise_error
|
||||
self._captured_requests = captured_requests if captured_requests is not None else []
|
||||
|
||||
def returns(self, text: str) -> "FakeProvider":
|
||||
def returns(self, text: str) -> 'FakeProvider':
|
||||
"""Configure provider to return a specific text response."""
|
||||
self._default_response = text
|
||||
self._streaming_chunks = [text]
|
||||
return self
|
||||
|
||||
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
|
||||
def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
|
||||
"""Configure provider to return streaming chunks."""
|
||||
self._streaming_chunks = chunks
|
||||
self._default_response = "".join(chunks)
|
||||
self._default_response = ''.join(chunks)
|
||||
return self
|
||||
|
||||
def raises(self, error: Exception) -> "FakeProvider":
|
||||
def raises(self, error: Exception) -> 'FakeProvider':
|
||||
"""Configure provider to raise an error."""
|
||||
self._raise_error = error
|
||||
return self
|
||||
|
||||
def timeout(self) -> "FakeProvider":
|
||||
def timeout(self) -> 'FakeProvider':
|
||||
"""Configure provider to simulate timeout."""
|
||||
return self.raises(TimeoutError("Provider timeout"))
|
||||
return self.raises(TimeoutError('Provider timeout'))
|
||||
|
||||
def auth_error(self) -> "FakeProvider":
|
||||
def auth_error(self) -> 'FakeProvider':
|
||||
"""Configure provider to simulate auth error."""
|
||||
return self.raises(Exception("Invalid API key"))
|
||||
return self.raises(Exception('Invalid API key'))
|
||||
|
||||
def rate_limit(self) -> "FakeProvider":
|
||||
def rate_limit(self) -> 'FakeProvider':
|
||||
"""Configure provider to simulate rate limit."""
|
||||
return self.raises(Exception("Rate limit exceeded"))
|
||||
return self.raises(Exception('Rate limit exceeded'))
|
||||
|
||||
def malformed(self) -> "FakeProvider":
|
||||
def malformed(self) -> 'FakeProvider':
|
||||
"""Configure provider to simulate malformed response."""
|
||||
self._default_response = None
|
||||
return self
|
||||
@@ -87,7 +87,7 @@ class FakeProvider:
|
||||
def _create_message(self, content: str) -> provider_message.Message:
|
||||
"""Create a provider message from text content."""
|
||||
return provider_message.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ class FakeProvider:
|
||||
) -> provider_message.MessageChunk:
|
||||
"""Create a provider message chunk."""
|
||||
return provider_message.MessageChunk(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=content,
|
||||
is_final=is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
@@ -116,13 +116,15 @@ class FakeProvider:
|
||||
) -> provider_message.Message:
|
||||
"""Simulate non-streaming LLM invocation."""
|
||||
# Capture request for assertions
|
||||
self._captured_requests.append({
|
||||
"query_id": query.query_id if query else None,
|
||||
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
"messages": messages,
|
||||
"funcs": funcs,
|
||||
"extra_args": extra_args,
|
||||
})
|
||||
self._captured_requests.append(
|
||||
{
|
||||
'query_id': query.query_id if query else None,
|
||||
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
'messages': messages,
|
||||
'funcs': funcs,
|
||||
'extra_args': extra_args,
|
||||
}
|
||||
)
|
||||
|
||||
# Simulate error if configured
|
||||
if self._raise_error:
|
||||
@@ -131,7 +133,7 @@ class FakeProvider:
|
||||
# Return response
|
||||
if self._default_response is None:
|
||||
# Malformed response
|
||||
return provider_message.Message(role="assistant", content=None)
|
||||
return provider_message.Message(role='assistant', content=None)
|
||||
|
||||
return self._create_message(self._default_response)
|
||||
|
||||
@@ -146,14 +148,16 @@ class FakeProvider:
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""Simulate streaming LLM invocation."""
|
||||
# Capture request for assertions
|
||||
self._captured_requests.append({
|
||||
"query_id": query.query_id if query else None,
|
||||
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
"messages": messages,
|
||||
"funcs": funcs,
|
||||
"extra_args": extra_args,
|
||||
"streaming": True,
|
||||
})
|
||||
self._captured_requests.append(
|
||||
{
|
||||
'query_id': query.query_id if query else None,
|
||||
'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
|
||||
'messages': messages,
|
||||
'funcs': funcs,
|
||||
'extra_args': extra_args,
|
||||
'streaming': True,
|
||||
}
|
||||
)
|
||||
|
||||
# Simulate error if configured
|
||||
if self._raise_error:
|
||||
@@ -161,12 +165,12 @@ class FakeProvider:
|
||||
|
||||
# Yield chunks
|
||||
for i, chunk in enumerate(self._streaming_chunks):
|
||||
is_final = (i == len(self._streaming_chunks) - 1)
|
||||
is_final = i == len(self._streaming_chunks) - 1
|
||||
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
|
||||
|
||||
|
||||
def fake_provider(
|
||||
default_response: str = "fake response",
|
||||
default_response: str = 'fake response',
|
||||
) -> FakeProvider:
|
||||
"""Create a FakeProvider with optional default response."""
|
||||
return FakeProvider(default_response=default_response)
|
||||
@@ -202,8 +206,8 @@ def fake_provider_malformed() -> FakeProvider:
|
||||
|
||||
def fake_model(
|
||||
*,
|
||||
uuid: str = "test-model-uuid",
|
||||
name: str = "test-model",
|
||||
uuid: str = 'test-model-uuid',
|
||||
name: str = 'test-model',
|
||||
abilities: list[str] = None,
|
||||
provider: FakeProvider = None,
|
||||
) -> Mock:
|
||||
@@ -212,7 +216,7 @@ def fake_model(
|
||||
model.model_entity = Mock()
|
||||
model.model_entity.uuid = uuid
|
||||
model.model_entity.name = name
|
||||
model.model_entity.abilities = abilities or ["func_call", "vision"]
|
||||
model.model_entity.abilities = abilities or ['func_call', 'vision']
|
||||
model.model_entity.extra_args = {}
|
||||
|
||||
# Attach fake provider
|
||||
@@ -221,4 +225,4 @@ def fake_model(
|
||||
|
||||
model.provider = provider
|
||||
|
||||
return model
|
||||
return model
|
||||
|
||||
@@ -3,4 +3,4 @@ Integration tests package.
|
||||
|
||||
These tests validate real system behavior with actual database/network resources.
|
||||
Run with: uv run pytest tests/integration/ -m "not slow" -q
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
API integration tests package.
|
||||
|
||||
Tests for HTTP API endpoints using Quart test client.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -48,6 +48,7 @@ def mock_circular_import_chain():
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -56,10 +57,12 @@ def fake_bot_app():
|
||||
"""Create FakeApp with bot services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
@@ -71,28 +74,29 @@ def fake_bot_app():
|
||||
|
||||
# Bot service
|
||||
app.bot_service = Mock()
|
||||
app.bot_service.get_bots = AsyncMock(return_value=[
|
||||
{
|
||||
app.bot_service.get_bots = AsyncMock(
|
||||
return_value=[
|
||||
{
|
||||
'uuid': 'test-bot-uuid',
|
||||
'name': 'Test Bot',
|
||||
'platform': 'telegram',
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
}
|
||||
]
|
||||
)
|
||||
app.bot_service.get_runtime_bot_info = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-bot-uuid',
|
||||
'name': 'Test Bot',
|
||||
'platform': 'telegram',
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
|
||||
}
|
||||
])
|
||||
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
|
||||
'uuid': 'test-bot-uuid',
|
||||
'name': 'Test Bot',
|
||||
'platform': 'telegram',
|
||||
'pipeline_uuid': 'test-pipeline-uuid',
|
||||
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
|
||||
})
|
||||
)
|
||||
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
|
||||
app.bot_service.update_bot = AsyncMock(return_value={})
|
||||
app.bot_service.delete_bot = AsyncMock()
|
||||
app.bot_service.list_event_logs = AsyncMock(return_value=(
|
||||
[{'uuid': 'log-1', 'message': 'test log'}],
|
||||
1
|
||||
))
|
||||
app.bot_service.list_event_logs = AsyncMock(return_value=([{'uuid': 'log-1', 'message': 'test log'}], 1))
|
||||
app.bot_service.send_message = AsyncMock()
|
||||
|
||||
# Platform manager
|
||||
@@ -118,10 +122,7 @@ class TestBotEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_bots_success(self, quart_test_client):
|
||||
"""GET /api/v1/platform/bots returns bot list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/platform/bots',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
response = await quart_test_client.get('/api/v1/platform/bots', headers={'Authorization': 'Bearer test_token'})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -135,7 +136,7 @@ class TestBotEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
|
||||
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -147,8 +148,7 @@ class TestBotEndpoints:
|
||||
async def test_get_single_bot_success(self, quart_test_client):
|
||||
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -162,7 +162,7 @@ class TestBotEndpoints:
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated Bot'}
|
||||
json={'name': 'Updated Bot'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -173,8 +173,7 @@ class TestBotEndpoints:
|
||||
async def test_delete_bot_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/platform/bots/test-bot-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/platform/bots/test-bot-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -190,7 +189,7 @@ class TestBotLogsEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/logs',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'from_index': -1, 'max_count': 10}
|
||||
json={'from_index': -1, 'max_count': 10},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -213,8 +212,8 @@ class TestBotSendMessageEndpoint:
|
||||
json={
|
||||
'target_type': 'person',
|
||||
'target_id': 'user123',
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||
}
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -228,7 +227,7 @@ class TestBotSendMessageEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/platform/bots/test-bot-uuid/send_message',
|
||||
headers={'Authorization': 'Bearer test_api_key'},
|
||||
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
|
||||
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -244,8 +243,8 @@ class TestBotSendMessageEndpoint:
|
||||
json={
|
||||
'target_type': 'invalid',
|
||||
'target_id': 'user123',
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}]
|
||||
}
|
||||
'message_chain': [{'type': 'text', 'text': 'Hello'}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -47,6 +47,7 @@ def mock_circular_import_chain():
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -55,10 +56,12 @@ def fake_embed_app():
|
||||
"""Create FakeApp with embed widget services (module scope)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# Create mock web_page_bot with valid UUID format
|
||||
mock_bot_entity = Mock()
|
||||
@@ -83,9 +86,7 @@ def fake_embed_app():
|
||||
|
||||
# WebSocket proxy bot with adapter
|
||||
mock_websocket_adapter = Mock()
|
||||
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
|
||||
{'id': 'msg-1', 'content': 'test message'}
|
||||
])
|
||||
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[{'id': 'msg-1', 'content': 'test message'}])
|
||||
mock_websocket_adapter.reset_session = Mock()
|
||||
mock_websocket_adapter.handle_websocket_message = AsyncMock()
|
||||
|
||||
@@ -117,9 +118,7 @@ class TestEmbedWidgetEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_success(self, quart_test_client):
|
||||
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
|
||||
)
|
||||
response = await quart_test_client.get('/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js')
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'javascript' in response.content_type
|
||||
@@ -127,18 +126,14 @@ class TestEmbedWidgetEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
|
||||
"""GET widget.js with invalid UUID returns 400."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/invalid-uuid/widget.js'
|
||||
)
|
||||
response = await quart_test_client.get('/api/v1/embed/invalid-uuid/widget.js')
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_widget_js_bot_not_found(self, quart_test_client):
|
||||
"""GET widget.js for non-existent bot returns 404."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
|
||||
)
|
||||
response = await quart_test_client.get('/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js')
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -164,8 +159,7 @@ class TestEmbedTurnstileVerifyEndpoint:
|
||||
async def test_turnstile_verify_no_secret(self, quart_test_client):
|
||||
"""POST turnstile verify without secret returns dummy token."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||
json={'token': 'test-token'}
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={'token': 'test-token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -177,8 +171,7 @@ class TestEmbedTurnstileVerifyEndpoint:
|
||||
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
|
||||
"""POST turnstile verify with invalid UUID returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/invalid-uuid/turnstile/verify',
|
||||
json={'token': 'test-token'}
|
||||
'/api/v1/embed/invalid-uuid/turnstile/verify', json={'token': 'test-token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -187,8 +180,7 @@ class TestEmbedTurnstileVerifyEndpoint:
|
||||
async def test_turnstile_verify_missing_token(self, quart_test_client):
|
||||
"""POST turnstile verify without token returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
|
||||
json={}
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify', json={}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -203,7 +195,7 @@ class TestEmbedMessagesEndpoint:
|
||||
"""GET messages/person returns messages."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -216,7 +208,7 @@ class TestEmbedMessagesEndpoint:
|
||||
"""GET messages/group returns messages."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -226,7 +218,7 @@ class TestEmbedMessagesEndpoint:
|
||||
"""GET messages with invalid session_type returns 400."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -241,7 +233,7 @@ class TestEmbedResetEndpoint:
|
||||
"""POST reset/person resets session."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -252,8 +244,7 @@ class TestEmbedResetEndpoint:
|
||||
async def test_reset_session_invalid_uuid(self, quart_test_client):
|
||||
"""POST reset with invalid UUID returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/invalid-uuid/reset/person',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
'/api/v1/embed/invalid-uuid/reset/person', headers={'Authorization': 'Bearer 1234567890.dummy'}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -269,7 +260,7 @@ class TestEmbedFeedbackEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 1}
|
||||
json={'message_id': 'msg-123', 'feedback_type': 1},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -283,7 +274,7 @@ class TestEmbedFeedbackEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 2}
|
||||
json={'message_id': 'msg-123', 'feedback_type': 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -294,7 +285,7 @@ class TestEmbedFeedbackEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
|
||||
headers={'Authorization': 'Bearer 1234567890.dummy'},
|
||||
json={'message_id': 'msg-123', 'feedback_type': 99}
|
||||
json={'message_id': 'msg-123', 'feedback_type': 99},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -49,6 +49,7 @@ def mock_circular_import_chain():
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -57,10 +58,12 @@ def fake_knowledge_app():
|
||||
"""Create FakeApp with knowledge services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
@@ -72,33 +75,35 @@ def fake_knowledge_app():
|
||||
|
||||
# Knowledge service
|
||||
app.knowledge_service = Mock()
|
||||
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
|
||||
{
|
||||
app.knowledge_service.get_knowledge_bases = AsyncMock(
|
||||
return_value=[
|
||||
{
|
||||
'uuid': 'test-kb-uuid',
|
||||
'name': 'Test Knowledge Base',
|
||||
'description': 'Test KB description',
|
||||
'engine_plugin_id': 'test/engine',
|
||||
'created_at': '2024-01-01T00:00:00',
|
||||
'updated_at': '2024-01-01T00:00:00',
|
||||
}
|
||||
]
|
||||
)
|
||||
app.knowledge_service.get_knowledge_base = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'test-kb-uuid',
|
||||
'name': 'Test Knowledge Base',
|
||||
'description': 'Test KB description',
|
||||
'engine_plugin_id': 'test/engine',
|
||||
'created_at': '2024-01-01T00:00:00',
|
||||
'updated_at': '2024-01-01T00:00:00',
|
||||
}
|
||||
])
|
||||
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
|
||||
'uuid': 'test-kb-uuid',
|
||||
'name': 'Test Knowledge Base',
|
||||
'description': 'Test KB description',
|
||||
'engine_plugin_id': 'test/engine',
|
||||
})
|
||||
)
|
||||
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
|
||||
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
|
||||
app.knowledge_service.delete_knowledge_base = AsyncMock()
|
||||
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
|
||||
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
|
||||
])
|
||||
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(
|
||||
return_value=[{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}]
|
||||
)
|
||||
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
|
||||
app.knowledge_service.delete_file = AsyncMock()
|
||||
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
|
||||
{'content': 'test result', 'score': 0.95}
|
||||
])
|
||||
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[{'content': 'test result', 'score': 0.95}])
|
||||
|
||||
# RAG manager
|
||||
app.rag_mgr = Mock()
|
||||
@@ -124,8 +129,7 @@ class TestKnowledgeBaseEndpoints:
|
||||
async def test_get_knowledge_bases_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases returns knowledge base list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/knowledge/bases', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -140,7 +144,7 @@ class TestKnowledgeBaseEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
|
||||
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -152,8 +156,7 @@ class TestKnowledgeBaseEndpoints:
|
||||
async def test_get_single_knowledge_base_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -167,7 +170,7 @@ class TestKnowledgeBaseEndpoints:
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated KB'}
|
||||
json={'name': 'Updated KB'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -178,8 +181,7 @@ class TestKnowledgeBaseEndpoints:
|
||||
async def test_delete_knowledge_base_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/knowledge/bases/test-kb-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -193,8 +195,7 @@ class TestKnowledgeBaseFilesEndpoints:
|
||||
async def test_get_files_success(self, quart_test_client):
|
||||
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -208,7 +209,7 @@ class TestKnowledgeBaseFilesEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
|
||||
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -220,8 +221,7 @@ class TestKnowledgeBaseFilesEndpoints:
|
||||
async def test_delete_file_from_knowledge_base(self, quart_test_client):
|
||||
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -237,7 +237,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
|
||||
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -249,9 +249,7 @@ class TestKnowledgeBaseRetrieveEndpoint:
|
||||
async def test_retrieve_without_query_returns_error(self, quart_test_client):
|
||||
"""POST retrieve without query returns 400."""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={}
|
||||
'/api/v1/knowledge/bases/test-kb-uuid/retrieve', headers={'Authorization': 'Bearer test_token'}, json={}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@@ -46,6 +46,7 @@ def mock_circular_import_chain():
|
||||
clear=clear,
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -54,10 +55,12 @@ def fake_monitoring_app():
|
||||
"""Create FakeApp with monitoring services (module scope)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
|
||||
app.user_service = Mock()
|
||||
@@ -67,40 +70,43 @@ def fake_monitoring_app():
|
||||
|
||||
# Monitoring service
|
||||
app.monitoring_service = Mock()
|
||||
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
|
||||
'total_messages': 100,
|
||||
'total_llm_calls': 50,
|
||||
'total_sessions': 20,
|
||||
'active_sessions': 5,
|
||||
'total_errors': 2,
|
||||
})
|
||||
app.monitoring_service.get_messages = AsyncMock(return_value=(
|
||||
[{'id': 'msg-1', 'content': 'test'}], 100
|
||||
))
|
||||
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
|
||||
[{'id': 'llm-1'}], 50
|
||||
))
|
||||
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
|
||||
[{'id': 'emb-1'}], 10
|
||||
))
|
||||
app.monitoring_service.get_sessions = AsyncMock(return_value=(
|
||||
[{'session_id': 'sess-1'}], 20
|
||||
))
|
||||
app.monitoring_service.get_errors = AsyncMock(return_value=(
|
||||
[{'id': 'err-1'}], 2
|
||||
))
|
||||
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
|
||||
'found': True,
|
||||
'session_id': 'sess-1',
|
||||
})
|
||||
app.monitoring_service.get_message_details = AsyncMock(return_value={
|
||||
'found': True,
|
||||
'message_id': 'msg-1',
|
||||
})
|
||||
app.monitoring_service.get_overview_metrics = AsyncMock(
|
||||
return_value={
|
||||
'total_messages': 100,
|
||||
'total_llm_calls': 50,
|
||||
'total_sessions': 20,
|
||||
'active_sessions': 5,
|
||||
'total_errors': 2,
|
||||
}
|
||||
)
|
||||
app.monitoring_service.get_messages = AsyncMock(return_value=([{'id': 'msg-1', 'content': 'test'}], 100))
|
||||
app.monitoring_service.get_llm_calls = AsyncMock(return_value=([{'id': 'llm-1'}], 50))
|
||||
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=([{'id': 'emb-1'}], 10))
|
||||
app.monitoring_service.get_traces = AsyncMock(return_value=([{'trace_id': 'trace-1'}], 1))
|
||||
app.monitoring_service.get_trace_details = AsyncMock(
|
||||
return_value={
|
||||
'found': True,
|
||||
'trace_id': 'trace-1',
|
||||
'trace': {'trace_id': 'trace-1'},
|
||||
'spans': [],
|
||||
}
|
||||
)
|
||||
app.monitoring_service.get_sessions = AsyncMock(return_value=([{'session_id': 'sess-1'}], 20))
|
||||
app.monitoring_service.get_errors = AsyncMock(return_value=([{'id': 'err-1'}], 2))
|
||||
app.monitoring_service.get_session_analysis = AsyncMock(
|
||||
return_value={
|
||||
'found': True,
|
||||
'session_id': 'sess-1',
|
||||
}
|
||||
)
|
||||
app.monitoring_service.get_message_details = AsyncMock(
|
||||
return_value={
|
||||
'found': True,
|
||||
'message_id': 'msg-1',
|
||||
}
|
||||
)
|
||||
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
|
||||
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
|
||||
[{'feedback_id': 'fb-1'}], 12
|
||||
))
|
||||
app.monitoring_service.get_feedback_list = AsyncMock(return_value=([{'feedback_id': 'fb-1'}], 12))
|
||||
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
|
||||
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
|
||||
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
|
||||
@@ -130,8 +136,7 @@ class TestMonitoringOverviewEndpoint:
|
||||
async def test_get_overview_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/overview returns metrics."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/overview',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/overview', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -147,8 +152,7 @@ class TestMonitoringMessagesEndpoint:
|
||||
async def test_get_messages_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/messages returns message list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/messages',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/messages', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -165,8 +169,7 @@ class TestMonitoringLLMCallsEndpoint:
|
||||
async def test_get_llm_calls_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/llm-calls."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/llm-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/llm-calls', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -180,8 +183,7 @@ class TestMonitoringEmbeddingCallsEndpoint:
|
||||
async def test_get_embedding_calls_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/embedding-calls."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/embedding-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/embedding-calls', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -195,8 +197,7 @@ class TestMonitoringSessionsEndpoint:
|
||||
async def test_get_sessions_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/sessions."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/sessions',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/sessions', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -210,8 +211,7 @@ class TestMonitoringErrorsEndpoint:
|
||||
async def test_get_errors_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/errors."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/errors',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/errors', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -225,13 +225,13 @@ class TestMonitoringAllDataEndpoint:
|
||||
async def test_get_all_data_success(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/data returns all data."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/data',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/data', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert 'overview' in data['data']
|
||||
assert 'traces' in data['data']
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
@@ -242,8 +242,7 @@ class TestMonitoringDetailsEndpoints:
|
||||
async def test_get_session_analysis(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/sessions/sess-1/analysis',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/sessions/sess-1/analysis', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -252,8 +251,16 @@ class TestMonitoringDetailsEndpoints:
|
||||
async def test_get_message_details(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/messages/{id}/details."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/messages/msg-1/details',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/messages/msg-1/details', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_trace_details(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/traces/{id}."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/traces/trace-1', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -267,8 +274,7 @@ class TestMonitoringFeedbackEndpoints:
|
||||
async def test_get_feedback_stats(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/feedback/stats."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/feedback/stats',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/feedback/stats', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -277,8 +283,7 @@ class TestMonitoringFeedbackEndpoints:
|
||||
async def test_get_feedback_list(self, quart_test_client):
|
||||
"""GET /api/v1/monitoring/feedback."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/feedback',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/feedback', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -292,8 +297,7 @@ class TestMonitoringExportEndpoint:
|
||||
async def test_export_messages(self, quart_test_client):
|
||||
"""GET export?type=messages returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=messages',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/export?type=messages', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -303,8 +307,7 @@ class TestMonitoringExportEndpoint:
|
||||
async def test_export_llm_calls(self, quart_test_client):
|
||||
"""GET export?type=llm-calls returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=llm-calls',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/export?type=llm-calls', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -313,8 +316,7 @@ class TestMonitoringExportEndpoint:
|
||||
async def test_export_sessions(self, quart_test_client):
|
||||
"""GET export?type=sessions returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=sessions',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/export?type=sessions', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -323,8 +325,7 @@ class TestMonitoringExportEndpoint:
|
||||
async def test_export_feedback(self, quart_test_client):
|
||||
"""GET export?type=feedback returns CSV."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/monitoring/export?type=feedback',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/monitoring/export?type=feedback', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -49,6 +49,7 @@ def mock_circular_import_chain():
|
||||
):
|
||||
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
|
||||
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -57,10 +58,12 @@ def fake_provider_app():
|
||||
"""Create FakeApp with provider/model services (module scope for reuse)."""
|
||||
app = FakeApp()
|
||||
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# Auth services
|
||||
app.user_service = Mock()
|
||||
@@ -72,27 +75,23 @@ def fake_provider_app():
|
||||
|
||||
# Provider service
|
||||
app.provider_service = Mock()
|
||||
app.provider_service.get_providers = AsyncMock(return_value=[
|
||||
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
|
||||
])
|
||||
app.provider_service.get_provider = AsyncMock(return_value={
|
||||
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
|
||||
})
|
||||
app.provider_service.get_providers = AsyncMock(
|
||||
return_value=[{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}]
|
||||
)
|
||||
app.provider_service.get_provider = AsyncMock(
|
||||
return_value={'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
|
||||
)
|
||||
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
|
||||
app.provider_service.update_provider = AsyncMock(return_value={})
|
||||
app.provider_service.delete_provider = AsyncMock()
|
||||
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
|
||||
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
|
||||
})
|
||||
app.provider_service.get_provider_model_counts = AsyncMock(
|
||||
return_value={'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0}
|
||||
)
|
||||
|
||||
# LLM model service
|
||||
app.llm_model_service = Mock()
|
||||
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
|
||||
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
|
||||
])
|
||||
app.llm_model_service.get_llm_model = AsyncMock(return_value={
|
||||
'uuid': 'test-model-uuid', 'name': 'gpt-4'
|
||||
})
|
||||
app.llm_model_service.get_llm_models = AsyncMock(return_value=[{'uuid': 'test-model-uuid', 'name': 'gpt-4'}])
|
||||
app.llm_model_service.get_llm_model = AsyncMock(return_value={'uuid': 'test-model-uuid', 'name': 'gpt-4'})
|
||||
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
|
||||
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
|
||||
app.llm_model_service.delete_llm_model = AsyncMock()
|
||||
@@ -133,8 +132,7 @@ class TestProviderEndpoints:
|
||||
async def test_get_providers_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/providers returns provider list with complete structure."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/providers', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -157,8 +155,7 @@ class TestProviderEndpoints:
|
||||
async def test_get_single_provider_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -177,7 +174,7 @@ class TestProviderEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/providers',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Provider', 'requester': 'chatcmpl'}
|
||||
json={'name': 'New Provider', 'requester': 'chatcmpl'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -194,7 +191,7 @@ class TestProviderEndpoints:
|
||||
response = await quart_test_client.put(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'Updated Provider'}
|
||||
json={'name': 'Updated Provider'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -205,8 +202,7 @@ class TestProviderEndpoints:
|
||||
async def test_delete_provider_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -215,8 +211,7 @@ class TestProviderEndpoints:
|
||||
async def test_get_provider_includes_model_counts(self, quart_test_client):
|
||||
"""GET provider response includes model counts."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/providers/test-provider-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/providers/test-provider-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -237,8 +232,7 @@ class TestModelEndpoints:
|
||||
async def test_get_llm_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/llm returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/llm',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/models/llm', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -250,8 +244,7 @@ class TestModelEndpoints:
|
||||
async def test_get_single_llm_model_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/llm/test-model-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -264,7 +257,7 @@ class TestModelEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/llm',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -276,8 +269,7 @@ class TestModelEndpoints:
|
||||
async def test_delete_llm_model_success(self, quart_test_client):
|
||||
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
|
||||
response = await quart_test_client.delete(
|
||||
'/api/v1/provider/models/llm/test-model-uuid',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/models/llm/test-model-uuid', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -291,8 +283,7 @@ class TestEmbeddingModelEndpoints:
|
||||
async def test_get_embedding_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/embedding returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/embedding',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/models/embedding', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -306,7 +297,7 @@ class TestEmbeddingModelEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/embedding',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -323,8 +314,7 @@ class TestRerankModelEndpoints:
|
||||
async def test_get_rerank_models_success(self, quart_test_client):
|
||||
"""GET /api/v1/provider/models/rerank returns model list."""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/provider/models/rerank',
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
'/api/v1/provider/models/rerank', headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -338,7 +328,7 @@ class TestRerankModelEndpoints:
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/provider/models/rerank',
|
||||
headers={'Authorization': 'Bearer test_token'},
|
||||
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
|
||||
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -20,6 +20,7 @@ pytestmark = pytest.mark.integration
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
@@ -69,6 +70,7 @@ def mock_circular_import_chain():
|
||||
|
||||
# ============== FAKE APPLICATION FOR API TESTS ==============
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_api_app():
|
||||
"""
|
||||
@@ -79,12 +81,14 @@ def fake_api_app():
|
||||
app = FakeApp()
|
||||
|
||||
# API-specific config
|
||||
app.instance_config.data.update({
|
||||
'api': {'port': 5300},
|
||||
'plugin': {'enable_marketplace': True},
|
||||
'space': {'url': 'https://space.langbot.app'},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
})
|
||||
app.instance_config.data.update(
|
||||
{
|
||||
'api': {'port': 5300},
|
||||
'plugin': {'enable_marketplace': True},
|
||||
'space': {'url': 'https://space.langbot.app'},
|
||||
'system': {'allow_modify_login_info': True, 'limitation': {}},
|
||||
}
|
||||
)
|
||||
|
||||
# API-specific services
|
||||
app.user_service = Mock()
|
||||
@@ -118,6 +122,7 @@ def fake_api_app():
|
||||
|
||||
# ============== QUART TEST CLIENT FIXTURE ==============
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def quart_test_client(fake_api_app, http_controller_cls):
|
||||
"""
|
||||
@@ -135,6 +140,7 @@ async def quart_test_client(fake_api_app, http_controller_cls):
|
||||
|
||||
# ============== API SMOKE TESTS ==============
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for /healthz endpoint - simplest smoke test."""
|
||||
@@ -222,8 +228,7 @@ class TestProtectedEndpoints:
|
||||
Protected endpoint returns 401 with invalid token.
|
||||
"""
|
||||
response = await quart_test_client.get(
|
||||
'/api/v1/user/check-token',
|
||||
headers={'Authorization': 'Bearer invalid_token'}
|
||||
'/api/v1/user/check-token', headers={'Authorization': 'Bearer invalid_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -254,10 +259,7 @@ class TestInvalidPayload:
|
||||
"""
|
||||
POST with wrong JSON structure returns stable error.
|
||||
"""
|
||||
response = await quart_test_client.post(
|
||||
'/api/v1/user/auth',
|
||||
json={'wrong_field': 'value'}
|
||||
)
|
||||
response = await quart_test_client.post('/api/v1/user/auth', json={'wrong_field': 'value'})
|
||||
|
||||
# Should return error with stable JSON structure
|
||||
assert response.status_code in (400, 500, 401)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
Persistence integration tests package.
|
||||
|
||||
Tests for database migrations and storage behavior.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -26,8 +26,8 @@ pytestmark = pytest.mark.integration
|
||||
@pytest.fixture
|
||||
def sqlite_db_url(tmp_path):
|
||||
"""Create SQLite URL with temporary database file."""
|
||||
db_file = tmp_path / "test_migrations.db"
|
||||
return f"sqlite+aiosqlite:///{db_file}"
|
||||
db_file = tmp_path / 'test_migrations.db'
|
||||
return f'sqlite+aiosqlite:///{db_file}'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -102,9 +102,9 @@ class TestSQLiteMigrationUpgrade:
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
assert rev is not None, 'Expected a revision after upgrade'
|
||||
# Head should be the latest migration
|
||||
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||
assert rev.startswith('0006'), f'Expected head to be 0006_*, got {rev}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_idempotent(self, sqlite_engine):
|
||||
@@ -131,7 +131,7 @@ class TestSQLiteMigrationUpgrade:
|
||||
await run_alembic_upgrade(sqlite_engine, 'head')
|
||||
|
||||
rev2 = await get_alembic_current(sqlite_engine)
|
||||
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
|
||||
|
||||
|
||||
class TestSQLiteMigrationFreshDatabase:
|
||||
@@ -149,8 +149,8 @@ class TestSQLiteMigrationFreshDatabase:
|
||||
4. Verify revision
|
||||
"""
|
||||
# Use different DB file for fresh test
|
||||
fresh_db_file = tmp_path / "test_migrations_fresh.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_db_file = tmp_path / 'test_migrations_fresh.db'
|
||||
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Create tables on fresh DB
|
||||
@@ -162,7 +162,7 @@ class TestSQLiteMigrationFreshDatabase:
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(fresh_engine)
|
||||
assert rev is not None, "Expected a revision on fresh DB"
|
||||
assert rev is not None, 'Expected a revision on fresh DB'
|
||||
|
||||
await fresh_engine.dispose()
|
||||
|
||||
@@ -181,8 +181,8 @@ class TestSQLiteMigrationFreshDatabase:
|
||||
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
|
||||
any arbitrary failure with try-except pass.
|
||||
"""
|
||||
fresh_db_file = tmp_path / "test_empty_migrations.db"
|
||||
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
|
||||
fresh_db_file = tmp_path / 'test_empty_migrations.db'
|
||||
fresh_url = f'sqlite+aiosqlite:///{fresh_db_file}'
|
||||
fresh_engine = create_async_engine(fresh_url)
|
||||
|
||||
# Capture the actual behavior
|
||||
@@ -201,23 +201,23 @@ class TestSQLiteMigrationFreshDatabase:
|
||||
# Verify specific behavior - one of two outcomes is expected
|
||||
if actual_result is not None:
|
||||
# Migration succeeded - verify revision exists
|
||||
assert actual_result is not None, "Revision should exist after successful migration"
|
||||
assert actual_result is not None, 'Revision should exist after successful migration'
|
||||
else:
|
||||
# Migration failed - verify the error type is known
|
||||
# Alembic typically raises specific errors for missing tables
|
||||
assert actual_error is not None, "Error should be captured if migration failed"
|
||||
assert actual_error is not None, 'Error should be captured if migration failed'
|
||||
# Log the error type for documentation (don't silently pass)
|
||||
error_type = type(actual_error).__name__
|
||||
# Acceptable error types for empty DB scenarios
|
||||
acceptable_errors = [
|
||||
'OperationalError', # SQLite table not found
|
||||
'ProgrammingError', # SQLAlchemy errors
|
||||
'CommandError', # Alembic command errors
|
||||
'CommandError', # Alembic command errors
|
||||
]
|
||||
assert error_type in acceptable_errors, (
|
||||
f"Unexpected error type: {error_type}. "
|
||||
f"This may indicate a regression in migration behavior. "
|
||||
f"Error: {actual_error}"
|
||||
f'Unexpected error type: {error_type}. '
|
||||
f'This may indicate a regression in migration behavior. '
|
||||
f'Error: {actual_error}'
|
||||
)
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ class TestSQLiteMigrationGetCurrent:
|
||||
|
||||
# No stamp - should return None
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||
assert rev is None, f'Expected None for unstamped DB, got {rev}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
|
||||
@@ -248,4 +248,4 @@ class TestSQLiteMigrationGetCurrent:
|
||||
await run_alembic_stamp(sqlite_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev == '0001_baseline'
|
||||
assert rev == '0001_baseline'
|
||||
|
||||
@@ -34,14 +34,14 @@ def postgres_url():
|
||||
"""Get PostgreSQL URL from environment."""
|
||||
url = os.environ.get('TEST_POSTGRES_URL')
|
||||
if not url:
|
||||
pytest.skip("TEST_POSTGRES_URL not set")
|
||||
pytest.skip('TEST_POSTGRES_URL not set')
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def postgres_engine(postgres_url):
|
||||
"""Create async PostgreSQL engine."""
|
||||
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
|
||||
engine = create_async_engine(postgres_url, isolation_level='AUTOCOMMIT')
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@@ -66,7 +66,7 @@ async def clean_alembic_version(postgres_engine):
|
||||
async with postgres_engine.begin() as conn:
|
||||
# Drop alembic_version table if exists
|
||||
try:
|
||||
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -74,7 +74,7 @@ async def clean_alembic_version(postgres_engine):
|
||||
|
||||
async with postgres_engine.begin() as conn:
|
||||
try:
|
||||
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||
await conn.execute(text('DROP TABLE IF EXISTS alembic_version'))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -83,9 +83,7 @@ class TestPostgreSQLMigrationBaseline:
|
||||
"""Tests for baseline stamp workflow on PostgreSQL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_baseline_stamp_sets_revision(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
async def test_postgres_baseline_stamp_sets_revision(self, postgres_engine, clean_tables, clean_alembic_version):
|
||||
"""
|
||||
Stamp baseline on existing tables sets correct revision.
|
||||
|
||||
@@ -106,9 +104,7 @@ class TestPostgreSQLMigrationBaseline:
|
||||
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_baseline_stamp_on_empty_db(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
async def test_postgres_baseline_stamp_on_empty_db(self, postgres_engine, clean_tables, clean_alembic_version):
|
||||
"""
|
||||
Stamp on empty database (no tables) still sets revision.
|
||||
|
||||
@@ -125,9 +121,7 @@ class TestPostgreSQLMigrationUpgrade:
|
||||
"""Tests for upgrade to head workflow on PostgreSQL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_upgrade_from_baseline_to_head(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
async def test_postgres_upgrade_from_baseline_to_head(self, postgres_engine, clean_tables, clean_alembic_version):
|
||||
"""
|
||||
Upgrade from baseline to head applies all migrations.
|
||||
|
||||
@@ -149,14 +143,12 @@ class TestPostgreSQLMigrationUpgrade:
|
||||
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration (0005 for current state)
|
||||
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||
assert rev is not None, 'Expected a revision after upgrade'
|
||||
# Head should be the latest migration.
|
||||
assert rev.startswith('0006'), f'Expected head to be 0006_*, got {rev}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_upgrade_idempotent(
|
||||
self, postgres_engine, clean_tables, clean_alembic_version
|
||||
):
|
||||
async def test_postgres_upgrade_idempotent(self, postgres_engine, clean_tables, clean_alembic_version):
|
||||
"""
|
||||
Running upgrade to head multiple times is idempotent.
|
||||
|
||||
@@ -180,7 +172,7 @@ class TestPostgreSQLMigrationUpgrade:
|
||||
await run_alembic_upgrade(postgres_engine, 'head')
|
||||
|
||||
rev2 = await get_alembic_current(postgres_engine)
|
||||
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
|
||||
assert rev2 == rev1, f'Expected {rev1}, got {rev2}'
|
||||
|
||||
|
||||
class TestPostgreSQLMigrationGetCurrent:
|
||||
@@ -199,7 +191,7 @@ class TestPostgreSQLMigrationGetCurrent:
|
||||
|
||||
# No stamp - should return None
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev is None, f"Expected None for unstamped DB, got {rev}"
|
||||
assert rev is None, f'Expected None for unstamped DB, got {rev}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_get_current_after_stamp_returns_revision(
|
||||
@@ -214,4 +206,4 @@ class TestPostgreSQLMigrationGetCurrent:
|
||||
await run_alembic_stamp(postgres_engine, '0001_baseline')
|
||||
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev == '0001_baseline'
|
||||
assert rev == '0001_baseline'
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
Pipeline integration tests package.
|
||||
|
||||
Tests for full pipeline flow using fake provider/runner.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -26,6 +26,7 @@ pytestmark = pytest.mark.integration
|
||||
|
||||
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
@@ -103,6 +104,7 @@ def mock_circular_import_chain():
|
||||
|
||||
# ============== FAKE RUNNER ==============
|
||||
|
||||
|
||||
class FakeRunner:
|
||||
"""Minimal fake runner class for pipeline integration tests.
|
||||
|
||||
@@ -117,12 +119,13 @@ class FakeRunner:
|
||||
self.config = config or {}
|
||||
self._provider = FakeProvider()
|
||||
# Instance-level configuration set via class attribute
|
||||
self._response_text = "fake response"
|
||||
self._response_text = 'fake response'
|
||||
self._raise_error = None
|
||||
|
||||
@classmethod
|
||||
def returns(cls, text: str):
|
||||
"""Create a runner class configured to return specific text."""
|
||||
|
||||
# We create a subclass with configured response
|
||||
class ConfiguredRunner(cls):
|
||||
name = cls.name
|
||||
@@ -132,11 +135,13 @@ class FakeRunner:
|
||||
def __init__(self, app=None, config=None):
|
||||
super().__init__(app, config)
|
||||
self._response_text = text
|
||||
|
||||
return ConfiguredRunner
|
||||
|
||||
@classmethod
|
||||
def raises(cls, error: Exception):
|
||||
"""Create a runner class configured to raise an error."""
|
||||
|
||||
class ConfiguredRunner(cls):
|
||||
name = cls.name
|
||||
_response_text = None
|
||||
@@ -145,6 +150,7 @@ class FakeRunner:
|
||||
def __init__(self, app=None, config=None):
|
||||
super().__init__(app, config)
|
||||
self._raise_error = error
|
||||
|
||||
return ConfiguredRunner
|
||||
|
||||
async def run(self, query):
|
||||
@@ -161,6 +167,7 @@ class FakeRunner:
|
||||
|
||||
# ============== PIPELINE APP FIXTURE ==============
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_app():
|
||||
"""
|
||||
@@ -187,6 +194,7 @@ def pipeline_app():
|
||||
def __init__(self, name, messages):
|
||||
self.name = name
|
||||
self.messages = messages
|
||||
|
||||
def copy(self):
|
||||
return MockPrompt(self.name, list(self.messages))
|
||||
|
||||
@@ -237,14 +245,17 @@ def fake_platform_adapter():
|
||||
@pytest.fixture
|
||||
def set_fake_runner():
|
||||
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
|
||||
|
||||
def _set_runner(runner_cls):
|
||||
# preregistered_runners expects a list of runner classes
|
||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
|
||||
|
||||
return _set_runner
|
||||
|
||||
|
||||
# ============== PIPELINE CONFIGURATION ==============
|
||||
|
||||
|
||||
def create_minimal_pipeline_config():
|
||||
"""Create minimal pipeline configuration for tests."""
|
||||
return {
|
||||
@@ -273,6 +284,7 @@ def create_minimal_pipeline_config():
|
||||
|
||||
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
|
||||
|
||||
|
||||
async def collect_processor_results(processor, query, stage_name):
|
||||
"""
|
||||
Helper to handle the coroutine -> async_generator pattern.
|
||||
@@ -296,6 +308,7 @@ async def collect_processor_results(processor, query, stage_name):
|
||||
|
||||
# ============== TESTS ==============
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestPipelineStageChainReal:
|
||||
"""Tests for real pipeline stage chain."""
|
||||
@@ -337,7 +350,7 @@ class TestPreProcessorStage:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query with adapter
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
@@ -365,7 +378,7 @@ class TestPreProcessorStage:
|
||||
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
query = text_query("test message content")
|
||||
query = text_query('test message content')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
@@ -396,11 +409,11 @@ class TestProcessorStage:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that returns pong
|
||||
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
query.resp_messages = []
|
||||
@@ -414,6 +427,7 @@ class TestProcessorStage:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -432,7 +446,7 @@ class TestProcessorStage:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
@@ -445,6 +459,7 @@ class TestProcessorStage:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -462,13 +477,13 @@ class TestProcessorStage:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
query.resp_messages = []
|
||||
|
||||
# Create reply chain
|
||||
reply_chain = text_chain("plugin response")
|
||||
reply_chain = text_chain('plugin response')
|
||||
|
||||
# Mock plugin_connector to prevent default with reply
|
||||
mock_event_ctx = Mock()
|
||||
@@ -479,6 +494,7 @@ class TestProcessorStage:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -502,7 +518,7 @@ class TestRunnerExceptionFlow:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises exception
|
||||
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
|
||||
fake_runner = FakeRunner().raises(ValueError('API Error: rate limit exceeded'))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with exception handling config
|
||||
@@ -510,7 +526,7 @@ class TestRunnerExceptionFlow:
|
||||
config['output']['misc']['exception-handling'] = 'show-hint'
|
||||
config['output']['misc']['failure-hint'] = 'Request failed.'
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
@@ -523,6 +539,7 @@ class TestRunnerExceptionFlow:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -541,14 +558,14 @@ class TestRunnerExceptionFlow:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises specific exception
|
||||
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
|
||||
fake_runner = FakeRunner().raises(RuntimeError('Custom runtime error'))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with show-error mode
|
||||
config = create_minimal_pipeline_config()
|
||||
config['output']['misc']['exception-handling'] = 'show-error'
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
@@ -561,6 +578,7 @@ class TestRunnerExceptionFlow:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -578,14 +596,14 @@ class TestRunnerExceptionFlow:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner that raises exception
|
||||
fake_runner = FakeRunner().raises(Exception("Hidden error"))
|
||||
fake_runner = FakeRunner().raises(Exception('Hidden error'))
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query with hide mode
|
||||
config = create_minimal_pipeline_config()
|
||||
config['output']['misc']['exception-handling'] = 'hide'
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
|
||||
@@ -598,6 +616,7 @@ class TestRunnerExceptionFlow:
|
||||
|
||||
# Create Processor stage
|
||||
from langbot.pkg.pipeline.process import process
|
||||
|
||||
processor_stage = process.Processor(pipeline_app)
|
||||
await processor_stage.initialize(query.pipeline_config)
|
||||
|
||||
@@ -623,7 +642,7 @@ class TestSendResponseBackStage:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query with response message
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
@@ -666,12 +685,12 @@ class TestStageChainIntegration:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Set fake runner
|
||||
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
|
||||
fake_runner = FakeRunner().returns('LANGBOT_FAKE_PONG')
|
||||
set_fake_runner(fake_runner)
|
||||
|
||||
# Create query
|
||||
config = create_minimal_pipeline_config()
|
||||
query = text_query("ping")
|
||||
query = text_query('ping')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = config
|
||||
query.resp_messages = []
|
||||
@@ -690,7 +709,7 @@ class TestStageChainIntegration:
|
||||
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||
]
|
||||
|
||||
@@ -711,6 +730,7 @@ class TestStageChainIntegration:
|
||||
|
||||
# Build resp_message_chain from resp_messages
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
for resp_msg in query.resp_messages:
|
||||
if resp_msg.content:
|
||||
query.resp_message_chain.append(text_chain(resp_msg.content))
|
||||
@@ -737,7 +757,7 @@ class TestStageChainIntegration:
|
||||
adapter, platform = fake_platform_adapter
|
||||
|
||||
# Create query
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.adapter = adapter
|
||||
query.pipeline_config = create_minimal_pipeline_config()
|
||||
|
||||
@@ -754,7 +774,7 @@ class TestStageChainIntegration:
|
||||
|
||||
pipeline_app.plugin_connector.emit_event = AsyncMock()
|
||||
pipeline_app.plugin_connector.emit_event.side_effect = [
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
|
||||
mock_event_ctx_processor, # Processor NormalMessageReceived
|
||||
]
|
||||
|
||||
@@ -775,4 +795,4 @@ class TestStageChainIntegration:
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
|
||||
# Chain stops here - no resp_messages
|
||||
assert len(query.resp_messages) == 0
|
||||
assert len(query.resp_messages) == 0
|
||||
|
||||
@@ -3,4 +3,4 @@ Smoke tests package.
|
||||
|
||||
Smoke tests verify basic functionality works without testing edge cases.
|
||||
Run with: uv run pytest tests/smoke/ -q
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -39,19 +39,19 @@ class TestFakeMessageFlow:
|
||||
assert app.instance_config is not None
|
||||
|
||||
# Verify default config
|
||||
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
|
||||
assert app.instance_config.data["command"]["enable"] is True
|
||||
assert app.instance_config.data['command']['prefix'] == ['/', '!']
|
||||
assert app.instance_config.data['command']['enable'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_returns_text(self):
|
||||
"""Test FakeProvider returns configured response."""
|
||||
provider = FakeProvider(default_response="test response")
|
||||
provider = FakeProvider(default_response='test response')
|
||||
|
||||
# Create mock model with provider
|
||||
model = fake_model(provider=provider)
|
||||
|
||||
# Create a simple query
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
# Simulate invoke
|
||||
result = await provider.invoke_llm(
|
||||
@@ -63,15 +63,15 @@ class TestFakeMessageFlow:
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.role == "assistant"
|
||||
assert result.content == "test response"
|
||||
assert result.role == 'assistant'
|
||||
assert result.content == 'test response'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_pong(self):
|
||||
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
|
||||
provider = fake_provider_pong()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("ping")
|
||||
query = text_query('ping')
|
||||
|
||||
result = await provider.invoke_llm(
|
||||
query=query,
|
||||
@@ -86,9 +86,9 @@ class TestFakeMessageFlow:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_provider_streaming(self):
|
||||
"""Test FakeProvider streaming response."""
|
||||
provider = FakeProvider().returns_streaming(["Hello", " World"])
|
||||
provider = FakeProvider().returns_streaming(['Hello', ' World'])
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
chunks = []
|
||||
# invoke_llm_stream returns an async generator, don't await it
|
||||
@@ -102,8 +102,8 @@ class TestFakeMessageFlow:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == "Hello"
|
||||
assert chunks[1].content == " World"
|
||||
assert chunks[0].content == 'Hello'
|
||||
assert chunks[1].content == ' World'
|
||||
assert chunks[1].is_final is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -111,9 +111,9 @@ class TestFakeMessageFlow:
|
||||
"""Test FakeProvider simulates timeout error."""
|
||||
provider = FakeProvider().timeout()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
with pytest.raises(TimeoutError, match="Provider timeout"):
|
||||
with pytest.raises(TimeoutError, match='Provider timeout'):
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
@@ -127,9 +127,9 @@ class TestFakeMessageFlow:
|
||||
"""Test FakeProvider simulates rate limit error."""
|
||||
provider = FakeProvider().rate_limit()
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
with pytest.raises(Exception, match='Rate limit exceeded'):
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
@@ -142,34 +142,34 @@ class TestFakeMessageFlow:
|
||||
async def test_fake_provider_captures_requests(self):
|
||||
"""Test FakeProvider captures request arguments."""
|
||||
provider = FakeProvider()
|
||||
model = fake_model(name="gpt-4", provider=provider)
|
||||
query = text_query("hello")
|
||||
model = fake_model(name='gpt-4', provider=provider)
|
||||
query = text_query('hello')
|
||||
|
||||
await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
funcs=[{"name": "test_func"}],
|
||||
extra_args={"temperature": 0.7},
|
||||
messages=[{'role': 'user', 'content': 'hello'}],
|
||||
funcs=[{'name': 'test_func'}],
|
||||
extra_args={'temperature': 0.7},
|
||||
)
|
||||
|
||||
captured = provider.get_captured_requests()
|
||||
assert len(captured) == 1
|
||||
assert captured[0]["model"] == "gpt-4"
|
||||
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
|
||||
assert captured[0]["funcs"] == [{"name": "test_func"}]
|
||||
assert captured[0]["extra_args"] == {"temperature": 0.7}
|
||||
assert captured[0]['model'] == 'gpt-4'
|
||||
assert captured[0]['messages'] == [{'role': 'user', 'content': 'hello'}]
|
||||
assert captured[0]['funcs'] == [{'name': 'test_func'}]
|
||||
assert captured[0]['extra_args'] == {'temperature': 0.7}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_capture_outbound(self):
|
||||
"""Test FakePlatform captures outbound messages."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
query = text_query("hello")
|
||||
platform = FakePlatform(bot_account_id='test-bot')
|
||||
query = text_query('hello')
|
||||
|
||||
# Simulate sending reply
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
reply_chain = text_chain("response text")
|
||||
reply_chain = text_chain('response text')
|
||||
event = query.message_event
|
||||
|
||||
await platform.reply_message(event, reply_chain, quote_origin=False)
|
||||
@@ -177,38 +177,38 @@ class TestFakeMessageFlow:
|
||||
# Verify captured
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) == 1
|
||||
assert outbound[0]["type"] == "reply"
|
||||
assert outbound[0]["message"] == reply_chain
|
||||
assert outbound[0]['type'] == 'reply'
|
||||
assert outbound[0]['message'] == reply_chain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_friend_message(self):
|
||||
"""Test FakePlatform creates friend message events."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
platform = FakePlatform(bot_account_id='test-bot')
|
||||
|
||||
event = platform.create_friend_message(
|
||||
text="hello bot",
|
||||
text='hello bot',
|
||||
sender_id=12345,
|
||||
nickname="TestUser",
|
||||
nickname='TestUser',
|
||||
)
|
||||
|
||||
assert event.type == "FriendMessage"
|
||||
assert event.type == 'FriendMessage'
|
||||
assert event.sender.id == 12345
|
||||
assert event.sender.nickname == "TestUser"
|
||||
assert str(event.message_chain) == "hello bot"
|
||||
assert event.sender.nickname == 'TestUser'
|
||||
assert str(event.message_chain) == 'hello bot'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_group_message_with_mention(self):
|
||||
"""Test FakePlatform creates group message with @mention."""
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
platform = FakePlatform(bot_account_id='test-bot')
|
||||
|
||||
event = platform.create_group_message(
|
||||
text="hello everyone",
|
||||
text='hello everyone',
|
||||
sender_id=12345,
|
||||
group_id=99999,
|
||||
mention_bot=True,
|
||||
)
|
||||
|
||||
assert event.type == "GroupMessage"
|
||||
assert event.type == 'GroupMessage'
|
||||
assert event.sender.id == 12345
|
||||
assert event.group.id == 99999
|
||||
|
||||
@@ -220,54 +220,57 @@ class TestFakeMessageFlow:
|
||||
async def test_query_factories_basic(self):
|
||||
"""Test basic query factory functions."""
|
||||
# Text query
|
||||
q1 = text_query("hello world")
|
||||
assert q1.launcher_type.value == "person"
|
||||
assert str(q1.message_chain) == "hello world"
|
||||
q1 = text_query('hello world')
|
||||
assert q1.launcher_type.value == 'person'
|
||||
assert str(q1.message_chain) == 'hello world'
|
||||
|
||||
# Group query
|
||||
from tests.factories import group_text_query
|
||||
q2 = group_text_query("hello group", group_id=88888)
|
||||
assert q2.launcher_type.value == "group"
|
||||
|
||||
q2 = group_text_query('hello group', group_id=88888)
|
||||
assert q2.launcher_type.value == 'group'
|
||||
assert q2.launcher_id == 88888
|
||||
|
||||
# Command query
|
||||
from tests.factories import command_query
|
||||
q3 = command_query("help", prefix="/")
|
||||
assert str(q3.message_chain) == "/help"
|
||||
|
||||
q3 = command_query('help', prefix='/')
|
||||
assert str(q3.message_chain) == '/help'
|
||||
|
||||
# Mention query
|
||||
from tests.factories import mention_query
|
||||
q4 = mention_query("hi", target="test-bot", group_id=77777)
|
||||
assert q4.launcher_type.value == "group"
|
||||
|
||||
q4 = mention_query('hi', target='test-bot', group_id=77777)
|
||||
assert q4.launcher_type.value == 'group'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_platform_send_failure(self):
|
||||
"""Test FakePlatform simulates send failure."""
|
||||
platform = FakePlatform().send_failure()
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
with pytest.raises(Exception, match="Platform send failure"):
|
||||
with pytest.raises(Exception, match='Platform send failure'):
|
||||
await platform.reply_message(
|
||||
query.message_event,
|
||||
text_chain("response"),
|
||||
text_chain('response'),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_platform_adapter(self):
|
||||
"""Test mock_platform_adapter helper."""
|
||||
platform = FakePlatform(bot_account_id="bot-123")
|
||||
platform = FakePlatform(bot_account_id='bot-123')
|
||||
adapter = mock_platform_adapter(platform)
|
||||
|
||||
assert adapter.bot_account_id == "bot-123"
|
||||
assert adapter.bot_account_id == 'bot-123'
|
||||
assert adapter._fake_platform is platform
|
||||
|
||||
# Test reply_message is wired
|
||||
from tests.factories.message import text_chain
|
||||
|
||||
query = text_query("test")
|
||||
await adapter.reply_message(query.message_event, text_chain("response"))
|
||||
query = text_query('test')
|
||||
await adapter.reply_message(query.message_event, text_chain('response'))
|
||||
|
||||
# Verify platform captured it
|
||||
assert len(platform.get_outbound_messages()) == 1
|
||||
@@ -293,18 +296,18 @@ class TestMessageFlowIntegration:
|
||||
Note: This does NOT run actual LangBot pipeline stages.
|
||||
"""
|
||||
# Setup
|
||||
platform = FakePlatform(bot_account_id="test-bot")
|
||||
platform = FakePlatform(bot_account_id='test-bot')
|
||||
provider = fake_provider_pong()
|
||||
model = fake_model(provider=provider)
|
||||
|
||||
# Create inbound message
|
||||
query = text_query("ping")
|
||||
query = text_query('ping')
|
||||
|
||||
# Simulate provider processing
|
||||
response = await provider.invoke_llm(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
messages=[{'role': 'user', 'content': 'ping'}],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
@@ -321,16 +324,16 @@ class TestMessageFlowIntegration:
|
||||
# Verify platform captured outbound
|
||||
outbound = platform.get_outbound_messages()
|
||||
assert len(outbound) == 1
|
||||
assert outbound[0]["type"] == "reply"
|
||||
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
|
||||
assert outbound[0]['type'] == 'reply'
|
||||
assert str(outbound[0]['message']) == FakeProvider.PONG_RESPONSE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_message_flow(self):
|
||||
"""Smoke test: streaming message flow."""
|
||||
platform = FakePlatform().supports_streaming()
|
||||
provider = FakeProvider().returns_streaming(["Hello", " there"])
|
||||
provider = FakeProvider().returns_streaming(['Hello', ' there'])
|
||||
model = fake_model(provider=provider)
|
||||
query = text_query("hi")
|
||||
query = text_query('hi')
|
||||
|
||||
chunks = []
|
||||
async for chunk in provider.invoke_llm_stream(
|
||||
@@ -344,8 +347,8 @@ class TestMessageFlowIntegration:
|
||||
|
||||
# Verify streaming worked
|
||||
assert len(chunks) == 2
|
||||
full_content = "".join(c.content for c in chunks)
|
||||
assert full_content == "Hello there"
|
||||
full_content = ''.join(c.content for c in chunks)
|
||||
assert full_content == 'Hello there'
|
||||
|
||||
# Verify platform supports streaming
|
||||
assert await platform.is_stream_output_supported() is True
|
||||
assert await platform.is_stream_output_supported() is True
|
||||
|
||||
@@ -15,22 +15,12 @@ import pathlib
|
||||
# Resolve project root (one level up from tests/)
|
||||
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
|
||||
VULN_FILE = (
|
||||
_PROJECT_ROOT
|
||||
/ "src"
|
||||
/ "langbot"
|
||||
/ "pkg"
|
||||
/ "api"
|
||||
/ "http"
|
||||
/ "controller"
|
||||
/ "groups"
|
||||
/ "system.py"
|
||||
)
|
||||
VULN_FILE = _PROJECT_ROOT / 'src' / 'langbot' / 'pkg' / 'api' / 'http' / 'controller' / 'groups' / 'system.py'
|
||||
|
||||
|
||||
def test_no_exec_call_in_system_controller():
|
||||
"""Verify there is no exec() call in system.py that takes user input."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
with open(VULN_FILE, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
tree = ast.parse(source)
|
||||
@@ -40,27 +30,26 @@ def test_no_exec_call_in_system_controller():
|
||||
if isinstance(node, ast.Call):
|
||||
func = node.func
|
||||
# Match bare exec() call
|
||||
if isinstance(func, ast.Name) and func.id == "exec":
|
||||
if isinstance(func, ast.Name) and func.id == 'exec':
|
||||
exec_calls.append(node.lineno)
|
||||
|
||||
assert len(exec_calls) == 0, (
|
||||
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
|
||||
"User-supplied code must never be passed to exec()."
|
||||
f'Found exec() call(s) at line(s) {exec_calls} in system.py. User-supplied code must never be passed to exec().'
|
||||
)
|
||||
|
||||
|
||||
def test_no_debug_exec_route():
|
||||
"""Verify the /debug/exec route is not registered."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
with open(VULN_FILE, 'r') as f:
|
||||
source = f.read()
|
||||
|
||||
assert "debug/exec" not in source, (
|
||||
"The /debug/exec route still exists in system.py. "
|
||||
"This endpoint allows arbitrary code execution and must be removed."
|
||||
assert 'debug/exec' not in source, (
|
||||
'The /debug/exec route still exists in system.py. '
|
||||
'This endpoint allows arbitrary code execution and must be removed.'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
test_no_exec_call_in_system_controller()
|
||||
test_no_debug_exec_route()
|
||||
print("All tests passed!")
|
||||
print('All tests passed!')
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""Unit tests for LangBot API HTTP service layer."""
|
||||
"""Unit tests for LangBot API HTTP service layer."""
|
||||
|
||||
@@ -13,4 +13,4 @@ Does NOT:
|
||||
- Call real provider/platform/network
|
||||
|
||||
Uses tests.factories.FakeApp as base mock application.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -132,9 +132,7 @@ class TestApiKeyServiceCreateApiKey:
|
||||
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
|
||||
result = await service.create_api_key('New Key', 'Test description')
|
||||
|
||||
assert insert_params == [
|
||||
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
|
||||
]
|
||||
assert insert_params == [{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}]
|
||||
assert result['key'].startswith('lbk_')
|
||||
assert result['key'] == 'lbk_fixed-token'
|
||||
assert result['name'] == 'New Key'
|
||||
|
||||
@@ -303,13 +303,7 @@ class TestBotServiceCreateBot:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_bots': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_bots': 2}}}
|
||||
ap.platform_mgr = SimpleNamespace()
|
||||
ap.platform_mgr.load_bot = AsyncMock()
|
||||
|
||||
@@ -318,9 +312,7 @@ class TestBotServiceCreateBot:
|
||||
bot2 = _create_mock_bot(bot_uuid='uuid-2')
|
||||
mock_result = _create_mock_result([bot1, bot2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Bot 1'})
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
@@ -352,6 +344,7 @@ class TestBotServiceCreateBot:
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -362,9 +355,7 @@ class TestBotServiceCreateBot:
|
||||
return bot_result # Get bot
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Bot'})
|
||||
|
||||
service = BotService(ap)
|
||||
|
||||
@@ -397,6 +388,7 @@ class TestBotServiceCreateBot:
|
||||
bot_result.first = Mock(return_value=_create_mock_bot())
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -492,6 +484,7 @@ class TestBotServiceUpdateBot:
|
||||
pipeline_result.first = Mock(return_value=mock_pipeline)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -582,10 +575,9 @@ class TestBotServiceListEventLogs:
|
||||
# Mock runtime bot with logger
|
||||
runtime_bot = SimpleNamespace()
|
||||
runtime_bot.logger = SimpleNamespace()
|
||||
runtime_bot.logger.get_logs = AsyncMock(return_value=(
|
||||
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
|
||||
5
|
||||
))
|
||||
runtime_bot.logger.get_logs = AsyncMock(
|
||||
return_value=([SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))], 5)
|
||||
)
|
||||
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
|
||||
|
||||
service = BotService(ap)
|
||||
@@ -646,11 +638,7 @@ class TestBotServiceSendMessage:
|
||||
service = BotService(ap)
|
||||
|
||||
# Execute with valid message chain format
|
||||
message_chain_data = {
|
||||
'messages': [
|
||||
{'type': 'text', 'data': {'text': 'Hello'}}
|
||||
]
|
||||
}
|
||||
message_chain_data = {'messages': [{'type': 'text', 'data': {'text': 'Hello'}}]}
|
||||
|
||||
# Patch the import location - the module imports inside the function
|
||||
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
|
||||
|
||||
@@ -6,6 +6,7 @@ Tests cover:
|
||||
- Knowledge engine discovery
|
||||
- File operations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -52,9 +53,7 @@ class TestGetKnowledgeBases:
|
||||
"""Test that it returns all knowledge base details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
|
||||
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
|
||||
)
|
||||
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[{'uuid': 'kb1', 'name': 'KB1'}])
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_bases()
|
||||
@@ -83,9 +82,7 @@ class TestGetKnowledgeBase:
|
||||
"""Test that it returns specific KB details."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'KB1'}
|
||||
)
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'KB1'})
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_knowledge_base('kb1')
|
||||
@@ -153,9 +150,7 @@ class TestCreateKnowledgeBase:
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
await service.create_knowledge_base({
|
||||
'knowledge_engine_plugin_id': 'author/engine'
|
||||
})
|
||||
await service.create_knowledge_base({'knowledge_engine_plugin_id': 'author/engine'})
|
||||
|
||||
# Check that default name 'Untitled' was used
|
||||
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
|
||||
@@ -170,20 +165,21 @@ class TestUpdateKnowledgeBase:
|
||||
"""Test that only mutable fields are updated."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
|
||||
return_value={'uuid': 'kb1', 'name': 'Updated'}
|
||||
)
|
||||
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value={'uuid': 'kb1', 'name': 'Updated'})
|
||||
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
|
||||
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
|
||||
# Pass both mutable and immutable fields
|
||||
await service.update_knowledge_base('kb1', {
|
||||
'name': 'New Name',
|
||||
'description': 'New desc',
|
||||
'uuid': 'should_be_filtered', # immutable
|
||||
})
|
||||
await service.update_knowledge_base(
|
||||
'kb1',
|
||||
{
|
||||
'name': 'New Name',
|
||||
'description': 'New desc',
|
||||
'uuid': 'should_be_filtered', # immutable
|
||||
},
|
||||
)
|
||||
|
||||
# Check that only mutable fields were passed to update
|
||||
call_args = mock_app.persistence_mgr.execute_async.call_args
|
||||
@@ -288,9 +284,7 @@ class TestListKnowledgeEngines:
|
||||
"""Test that it returns empty list and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
|
||||
side_effect=Exception('Connection error')
|
||||
)
|
||||
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(side_effect=Exception('Connection error'))
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.list_knowledge_engines()
|
||||
@@ -386,12 +380,10 @@ class TestGetEngineSchemas:
|
||||
"""Test that it returns empty dict and logs warning on exception."""
|
||||
knowledge_module = get_knowledge_service_module()
|
||||
mock_app = create_mock_app()
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
|
||||
side_effect=Exception('Plugin error')
|
||||
)
|
||||
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(side_effect=Exception('Plugin error'))
|
||||
|
||||
service = knowledge_module.KnowledgeService(mock_app)
|
||||
result = await service.get_engine_creation_schema('author/engine')
|
||||
|
||||
assert result == {}
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
mock_app.logger.warning.assert_called_once()
|
||||
|
||||
@@ -174,9 +174,7 @@ class TestMaintenanceServiceGetStorageAnalysis:
|
||||
# Setup
|
||||
ap = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
|
||||
}
|
||||
ap.instance_config.data = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.logger = SimpleNamespace()
|
||||
ap.logger.warning = Mock()
|
||||
@@ -292,12 +290,8 @@ class TestMaintenanceServiceGetStorageAnalysis:
|
||||
service._file_count = Mock(return_value=0)
|
||||
service._monitoring_counts = AsyncMock(return_value={})
|
||||
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[
|
||||
{'key': 'old_file', 'size_bytes': 100}
|
||||
])
|
||||
service._expired_log_candidates = Mock(return_value=[
|
||||
{'name': 'old_log', 'size_bytes': 50}
|
||||
])
|
||||
service._expired_uploaded_candidates = AsyncMock(return_value=[{'key': 'old_file', 'size_bytes': 100}])
|
||||
service._expired_log_candidates = Mock(return_value=[{'name': 'old_log', 'size_bytes': 50}])
|
||||
|
||||
# Execute
|
||||
result = await service.get_storage_analysis()
|
||||
@@ -367,6 +361,7 @@ class TestMaintenanceServiceBinaryStorageStats:
|
||||
size_result = _create_mock_result(scalar_value=5000)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -396,6 +391,7 @@ class TestMaintenanceServiceBinaryStorageStats:
|
||||
count_result = _create_mock_result(scalar_value=5)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -821,4 +817,4 @@ class TestMaintenanceServiceExpiredLocalUploadCandidates:
|
||||
result = service._expired_local_upload_candidates(7, include_paths=True)
|
||||
|
||||
# Verify - path included
|
||||
assert 'path' in result[0]
|
||||
assert 'path' in result[0]
|
||||
|
||||
@@ -186,13 +186,7 @@ class TestMCPServiceCreateMCPServer:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_extensions': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_extensions': 2}}}
|
||||
ap.plugin_connector = SimpleNamespace()
|
||||
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
|
||||
|
||||
@@ -252,6 +246,7 @@ class TestMCPServiceCreateMCPServer:
|
||||
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -361,6 +356,7 @@ class TestMCPServiceUpdateMCPServer:
|
||||
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -394,6 +390,7 @@ class TestMCPServiceUpdateMCPServer:
|
||||
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -432,6 +429,7 @@ class TestMCPServiceUpdateMCPServer:
|
||||
|
||||
# Mock for: first select -> update -> second select (for updated server)
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -465,6 +463,7 @@ class TestMCPServiceUpdateMCPServer:
|
||||
|
||||
# Mock execute for select and update
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -499,6 +498,7 @@ class TestMCPServiceDeleteMCPServer:
|
||||
server = _create_mock_mcp_server(name='Server to Delete')
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -530,6 +530,7 @@ class TestMCPServiceDeleteMCPServer:
|
||||
server = _create_mock_mcp_server(name='Not in Sessions')
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -559,6 +560,7 @@ class TestMCPServiceDeleteMCPServer:
|
||||
|
||||
# No server found
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -596,9 +598,7 @@ class TestMCPServiceTestMCPServer:
|
||||
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=123)
|
||||
)
|
||||
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=123))
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
@@ -634,9 +634,7 @@ class TestMCPServiceTestMCPServer:
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
ap.task_mgr.create_user_task = Mock(
|
||||
return_value=SimpleNamespace(id=456)
|
||||
)
|
||||
ap.task_mgr.create_user_task = Mock(return_value=SimpleNamespace(id=456))
|
||||
|
||||
service = MCPService(ap)
|
||||
|
||||
@@ -645,4 +643,4 @@ class TestMCPServiceTestMCPServer:
|
||||
|
||||
# Verify - load_mcp_server called
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
||||
assert task_id == 456
|
||||
assert task_id == 456
|
||||
|
||||
@@ -167,6 +167,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
mock_provider_result = _create_mock_result([])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
return mock_result if call_count == 0 else mock_provider_result
|
||||
|
||||
@@ -200,6 +201,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -239,6 +241,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -279,6 +282,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -337,9 +341,7 @@ class TestLLMModelsServiceGetLLMModelsByProvider:
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'model-1', 'name': 'Model 1'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'model-1', 'name': 'Model 1'})
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
@@ -371,12 +373,14 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'name': 'New LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
model_uuid = await service.create_llm_model(
|
||||
{
|
||||
'name': 'New LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
@@ -400,13 +404,16 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_llm_model({
|
||||
'uuid': 'preserved-uuid',
|
||||
'name': 'Preserved UUID Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}, preserve_uuid=True)
|
||||
model_uuid = await service.create_llm_model(
|
||||
{
|
||||
'uuid': 'preserved-uuid',
|
||||
'name': 'Preserved UUID Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
},
|
||||
preserve_uuid=True,
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
@@ -459,12 +466,14 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_llm_model({
|
||||
'name': 'No Provider Model',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
await service.create_llm_model(
|
||||
{
|
||||
'name': 'No Provider Model',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
async def test_create_llm_model_with_provider_data(self):
|
||||
"""Creates provider when provider data provided."""
|
||||
@@ -490,16 +499,18 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute - with provider data (no UUID)
|
||||
result_uuid = await service.create_llm_model({
|
||||
'name': 'Model with New Provider',
|
||||
'provider': {
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
},
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
})
|
||||
result_uuid = await service.create_llm_model(
|
||||
{
|
||||
'name': 'Model with New Provider',
|
||||
'provider': {
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
},
|
||||
'abilities': [],
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
# Verify - provider_service was called and UUID generated
|
||||
ap.provider_service.find_or_create_provider.assert_called_once()
|
||||
@@ -525,11 +536,14 @@ class TestLLMModelsServiceUpdateLLMModel:
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_llm_model('existing-uuid', {
|
||||
'uuid': 'should-be-removed',
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
})
|
||||
await service.update_llm_model(
|
||||
'existing-uuid',
|
||||
{
|
||||
'uuid': 'should-be-removed',
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
},
|
||||
)
|
||||
|
||||
# Verify - remove and load called
|
||||
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
|
||||
@@ -549,10 +563,13 @@ class TestLLMModelsServiceUpdateLLMModel:
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.update_llm_model('model-uuid', {
|
||||
'name': 'Update',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
await service.update_llm_model(
|
||||
'model-uuid',
|
||||
{
|
||||
'name': 'Update',
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
},
|
||||
)
|
||||
|
||||
async def test_update_llm_model_reloads_context_length_as_column(self):
|
||||
"""Updates runtime model with context_length outside extra_args."""
|
||||
@@ -618,9 +635,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'embedding-uuid', 'name': 'Test'})
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
@@ -643,6 +658,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModels:
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -683,6 +699,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModel:
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -742,11 +759,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_embedding_model({
|
||||
'name': 'New Embedding',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
model_uuid = await service.create_embedding_model(
|
||||
{
|
||||
'name': 'New Embedding',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
@@ -767,11 +786,13 @@ class TestEmbeddingModelsServiceCreateEmbeddingModel:
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_embedding_model({
|
||||
'name': 'No Provider Embedding',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
await service.create_embedding_model(
|
||||
{
|
||||
'name': 'No Provider Embedding',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
|
||||
@@ -829,6 +850,7 @@ class TestRerankModelsServiceGetRerankModels:
|
||||
mock_provider_result = _create_mock_result([provider])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -869,6 +891,7 @@ class TestRerankModelsServiceGetRerankModel:
|
||||
mock_provider_result = _create_mock_result([], first_item=provider)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -928,11 +951,13 @@ class TestRerankModelsServiceCreateRerankModel:
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
# Execute
|
||||
model_uuid = await service.create_rerank_model({
|
||||
'name': 'New Rerank',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
})
|
||||
model_uuid = await service.create_rerank_model(
|
||||
{
|
||||
'name': 'New Rerank',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert model_uuid is not None
|
||||
@@ -952,11 +977,13 @@ class TestRerankModelsServiceCreateRerankModel:
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(Exception, match='provider not found'):
|
||||
await service.create_rerank_model({
|
||||
'name': 'No Provider Rerank',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
})
|
||||
await service.create_rerank_model(
|
||||
{
|
||||
'name': 'No Provider Rerank',
|
||||
'provider_uuid': 'nonexistent',
|
||||
'extra_args': {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestRerankModelsServiceDeleteRerankModel:
|
||||
@@ -995,9 +1022,7 @@ class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'emb-1', 'name': 'Embedding 1'})
|
||||
|
||||
service = EmbeddingModelsService(ap)
|
||||
|
||||
@@ -1022,9 +1047,7 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
|
||||
mock_result = _create_mock_result([model1, model2])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'})
|
||||
|
||||
service = RerankModelsService(ap)
|
||||
|
||||
@@ -1042,14 +1065,10 @@ class TestValidateProviderSupports:
|
||||
def _make_ap(requester_name: str, support_type):
|
||||
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
|
||||
manifest = SimpleNamespace(spec={'support_type': support_type})
|
||||
runtime_provider = SimpleNamespace(
|
||||
provider_entity=SimpleNamespace(requester=requester_name)
|
||||
)
|
||||
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester=requester_name))
|
||||
model_mgr = SimpleNamespace(
|
||||
provider_dict={'p1': runtime_provider},
|
||||
get_available_requester_manifest_by_name=lambda name: manifest
|
||||
if name == requester_name
|
||||
else None,
|
||||
get_available_requester_manifest_by_name=lambda name: manifest if name == requester_name else None,
|
||||
)
|
||||
return SimpleNamespace(model_mgr=model_mgr)
|
||||
|
||||
@@ -1066,9 +1085,7 @@ class TestValidateProviderSupports:
|
||||
async def test_allows_when_support_type_missing(self):
|
||||
# Manifest without support_type must not block (backward compatible)
|
||||
manifest = SimpleNamespace(spec={})
|
||||
runtime_provider = SimpleNamespace(
|
||||
provider_entity=SimpleNamespace(requester='legacy')
|
||||
)
|
||||
runtime_provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='legacy'))
|
||||
model_mgr = SimpleNamespace(
|
||||
provider_dict={'p1': runtime_provider},
|
||||
get_available_requester_manifest_by_name=lambda name: manifest,
|
||||
|
||||
@@ -215,13 +215,7 @@ class TestPipelineServiceCreatePipeline:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
@@ -229,9 +223,7 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'})
|
||||
|
||||
service = PipelineService(ap)
|
||||
|
||||
@@ -258,14 +250,14 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
# Mock persistence for insert
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
|
||||
|
||||
# Mock the file read for default config - patch at the utils module level
|
||||
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
with patch(
|
||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
||||
):
|
||||
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
# Verify
|
||||
@@ -286,7 +278,9 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
|
||||
)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock()
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
@@ -296,7 +290,9 @@ class TestPipelineServiceCreatePipeline:
|
||||
# Mock the file read
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
with patch(
|
||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
||||
):
|
||||
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
|
||||
|
||||
# Verify - execute was called
|
||||
@@ -316,10 +312,12 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
service.get_pipeline = AsyncMock(return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
})
|
||||
service.get_pipeline = AsyncMock(
|
||||
return_value={
|
||||
'uuid': 'new-uuid',
|
||||
'extensions_preferences': {},
|
||||
}
|
||||
)
|
||||
|
||||
insert_params = []
|
||||
|
||||
@@ -339,7 +337,9 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
default_config = {}
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
|
||||
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
|
||||
with patch(
|
||||
'langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'
|
||||
):
|
||||
await service.create_pipeline({'name': 'New Pipeline'})
|
||||
|
||||
assert len(insert_params) == 1
|
||||
@@ -353,6 +353,7 @@ class TestPipelineServiceCreatePipeline:
|
||||
|
||||
class _MockResultWithBots:
|
||||
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
|
||||
|
||||
def __init__(self, bots_list):
|
||||
self._bots_list = bots_list
|
||||
|
||||
@@ -428,6 +429,7 @@ class TestPipelineServiceUpdatePipeline:
|
||||
# 1. UPDATE (line 125) - returns Mock (no result needed)
|
||||
# 2. SELECT bots (line 136) - returns bot_result with .all()
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -528,13 +530,7 @@ class TestPipelineServiceCopyPipeline:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.instance_config = SimpleNamespace()
|
||||
ap.instance_config.data = {
|
||||
'system': {
|
||||
'limitation': {
|
||||
'max_pipelines': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': 2}}}
|
||||
ap.pipeline_mgr = SimpleNamespace()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
ap.ver_mgr = SimpleNamespace()
|
||||
@@ -542,10 +538,12 @@ class TestPipelineServiceCopyPipeline:
|
||||
|
||||
service = PipelineService(ap)
|
||||
# Mock get_pipelines to return 2 pipelines
|
||||
service.get_pipelines = AsyncMock(return_value=[
|
||||
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
||||
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
||||
])
|
||||
service.get_pipelines = AsyncMock(
|
||||
return_value=[
|
||||
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
|
||||
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
|
||||
]
|
||||
)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Maximum number of pipelines'):
|
||||
@@ -642,9 +640,7 @@ class TestPipelineServiceCopyPipeline:
|
||||
service = PipelineService(ap)
|
||||
service.get_pipelines = AsyncMock(return_value=[])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'uuid': 'copy-uuid', 'is_default': False}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
||||
|
||||
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
|
||||
|
||||
@@ -681,11 +677,10 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
ap.pipeline_mgr.remove_pipeline = AsyncMock()
|
||||
ap.pipeline_mgr.load_pipeline = AsyncMock()
|
||||
|
||||
original_pipeline = _create_mock_pipeline(
|
||||
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
|
||||
)
|
||||
original_pipeline = _create_mock_pipeline(extensions_preferences={'enable_all_plugins': True, 'plugins': []})
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -700,7 +695,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -711,7 +706,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
'extensions_preferences': {
|
||||
'enable_all_plugins': False,
|
||||
'plugins': [{'plugin_uuid': 'plugin-1'}],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -738,6 +733,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
original_pipeline = _create_mock_pipeline()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -752,7 +748,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
'extensions_preferences': {
|
||||
'enable_all_mcp_servers': False,
|
||||
'mcp_servers': ['mcp-server-1'],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -794,6 +790,7 @@ class TestPipelineServiceUpdatePipelineExtensions:
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
@@ -245,12 +245,14 @@ class TestModelProviderServiceCreateProvider:
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
provider_uuid = await service.create_provider({
|
||||
'name': 'New Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
provider_uuid = await service.create_provider(
|
||||
{
|
||||
'name': 'New Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
# Verify - UUID is generated
|
||||
assert provider_uuid is not None
|
||||
@@ -274,12 +276,14 @@ class TestModelProviderServiceCreateProvider:
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
result_uuid = await service.create_provider({
|
||||
'name': 'Runtime Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
})
|
||||
result_uuid = await service.create_provider(
|
||||
{
|
||||
'name': 'Runtime Provider',
|
||||
'requester': 'openai',
|
||||
'base_url': 'https://api.openai.com',
|
||||
'api_keys': ['key'],
|
||||
}
|
||||
)
|
||||
|
||||
# Verify - provider added to runtime dict and UUID generated
|
||||
ap.model_mgr.load_provider.assert_called_once()
|
||||
@@ -302,10 +306,13 @@ class TestModelProviderServiceUpdateProvider:
|
||||
service = ModelProviderService(ap)
|
||||
|
||||
# Execute
|
||||
await service.update_provider('existing-uuid', {
|
||||
'uuid': 'should-be-removed', # Will be removed
|
||||
'name': 'Updated Name',
|
||||
})
|
||||
await service.update_provider(
|
||||
'existing-uuid',
|
||||
{
|
||||
'uuid': 'should-be-removed', # Will be removed
|
||||
'name': 'Updated Name',
|
||||
},
|
||||
)
|
||||
|
||||
# Verify - reload called
|
||||
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
|
||||
@@ -364,6 +371,7 @@ class TestModelProviderServiceDeleteProvider:
|
||||
rerank_result.first = Mock(return_value=None)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -396,6 +404,7 @@ class TestModelProviderServiceDeleteProvider:
|
||||
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -454,6 +463,7 @@ class TestModelProviderServiceGetProviderModelCounts:
|
||||
rerank_result.scalar = Mock(return_value=1)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -637,9 +647,7 @@ class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
|
||||
await service.update_space_model_provider_api_keys('space-api-key')
|
||||
|
||||
# Verify - update and reload called for Space provider UUID
|
||||
ap.model_mgr.reload_provider.assert_called_once_with(
|
||||
'00000000-0000-0000-0000-000000000000'
|
||||
)
|
||||
ap.model_mgr.reload_provider.assert_called_once_with('00000000-0000-0000-0000-000000000000')
|
||||
|
||||
|
||||
class TestModelProviderServiceScanProviderModels:
|
||||
@@ -795,9 +803,7 @@ class TestModelProviderServiceScanProviderModels:
|
||||
runtime_provider.token_mgr = Mock()
|
||||
runtime_provider.token_mgr.get_token = Mock(return_value='token')
|
||||
runtime_provider.token_mgr.tokens = ['token']
|
||||
runtime_provider.requester.scan_models = AsyncMock(
|
||||
side_effect=NotImplementedError('scan not supported')
|
||||
)
|
||||
runtime_provider.requester.scan_models = AsyncMock(side_effect=NotImplementedError('scan not supported'))
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
@@ -848,9 +854,7 @@ class TestModelProviderServiceScanProviderModels:
|
||||
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
|
||||
|
||||
# Mock existing LLM model
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
|
||||
return_value=[{'name': 'Existing Model'}]
|
||||
)
|
||||
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[{'name': 'Existing Model'}])
|
||||
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
service = ModelProviderService(ap)
|
||||
@@ -863,4 +867,4 @@ class TestModelProviderServiceScanProviderModels:
|
||||
assert existing_model['already_added'] is True
|
||||
|
||||
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
|
||||
assert new_model['already_added'] is False
|
||||
assert new_model['already_added'] is False
|
||||
|
||||
@@ -393,14 +393,16 @@ class TestSpaceServiceRefreshToken:
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
@@ -429,10 +431,12 @@ class TestSpaceServiceRefreshToken:
|
||||
# Mock HTTP response with error
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 1,
|
||||
'msg': 'Invalid refresh token',
|
||||
})
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
'code': 1,
|
||||
'msg': 'Invalid refresh token',
|
||||
}
|
||||
)
|
||||
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
@@ -489,14 +493,16 @@ class TestSpaceServiceExchangeOAuthCode:
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'access_token': 'new_access_token',
|
||||
'refresh_token': 'new_refresh_token',
|
||||
'expires_in': 3600,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
@@ -555,13 +561,15 @@ class TestSpaceServiceGetUserInfoRaw:
|
||||
# Mock HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'email': 'test@example.com',
|
||||
'credits': 100,
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'email': 'test@example.com',
|
||||
'credits': 100,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
@@ -669,27 +677,29 @@ class TestSpaceServiceGetModels:
|
||||
# Mock HTTP response with proper model data matching SpaceModel schema
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'models': [
|
||||
{
|
||||
'uuid': 'uuid-1',
|
||||
'model_id': 'model-1',
|
||||
'provider': 'provider-1',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
{
|
||||
'uuid': 'uuid-2',
|
||||
'model_id': 'model-2',
|
||||
'provider': 'provider-2',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
]
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
'code': 0,
|
||||
'data': {
|
||||
'models': [
|
||||
{
|
||||
'uuid': 'uuid-1',
|
||||
'model_id': 'model-1',
|
||||
'provider': 'provider-1',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
{
|
||||
'uuid': 'uuid-2',
|
||||
'model_id': 'model-2',
|
||||
'provider': 'provider-2',
|
||||
'category': 'chat',
|
||||
'status': 'active',
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
|
||||
mock_session_obj = MagicMock()
|
||||
@@ -775,4 +785,4 @@ class TestSpaceServiceCreditsCache:
|
||||
# Verify - cache updated
|
||||
assert result == 500
|
||||
assert 'test@example.com' in service._credits_cache
|
||||
assert service._credits_cache['test@example.com'][0] == 500
|
||||
assert service._credits_cache['test@example.com'][0] == 500
|
||||
|
||||
@@ -495,6 +495,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -565,6 +566,7 @@ class TestUserServiceCreateOrUpdateSpaceUser:
|
||||
|
||||
# First call (line 138) returns None, second call (line 194) returns new_user
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_by_space_uuid(uuid):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -605,4 +607,4 @@ class TestUserServiceCreateUserLock:
|
||||
|
||||
# Verify lock exists
|
||||
assert hasattr(service, '_create_user_lock')
|
||||
assert service._create_user_lock is not None
|
||||
assert service._create_user_lock is not None
|
||||
|
||||
@@ -132,6 +132,7 @@ class TestWebhookServiceCreateWebhook:
|
||||
|
||||
# execute_async returns different results
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -181,6 +182,7 @@ class TestWebhookServiceCreateWebhook:
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -217,6 +219,7 @@ class TestWebhookServiceCreateWebhook:
|
||||
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(query):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -225,9 +228,7 @@ class TestWebhookServiceCreateWebhook:
|
||||
return _create_mock_result(first_item=created_webhook)
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={'id': 1, 'enabled': False}
|
||||
)
|
||||
ap.persistence_mgr.serialize_model = Mock(return_value={'id': 1, 'enabled': False})
|
||||
|
||||
service = WebhookService(ap)
|
||||
|
||||
@@ -503,4 +504,4 @@ class TestWebhookServiceGetEnabledWebhooks:
|
||||
result = await service.get_enabled_webhooks()
|
||||
|
||||
# Verify - should be empty (SQL would filter disabled)
|
||||
assert result == []
|
||||
assert result == []
|
||||
|
||||
@@ -407,7 +407,9 @@ def test_box_service_forced_template_ignores_pipeline_config():
|
||||
launcher_type='person',
|
||||
launcher_id='test_user',
|
||||
sender_id='test_user',
|
||||
pipeline_config={'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}},
|
||||
pipeline_config={
|
||||
'ai': {'local-agent': {'box-session-id-template': '{launcher_type}_{launcher_id}_{sender_id}'}}
|
||||
},
|
||||
)
|
||||
|
||||
assert service.resolve_box_session_id(query) == 'global'
|
||||
@@ -1527,9 +1529,7 @@ class TestBuildSkillExtraMounts:
|
||||
{'host_path': '/box/skills/b', 'mount_path': '/workspace/.skills/b', 'mode': 'rw'},
|
||||
]
|
||||
# No skill is dropped, so no "missing" warning should be logged.
|
||||
assert not any(
|
||||
'package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list
|
||||
)
|
||||
assert not any('package_root missing' in str(call.args[0]) for call in logger.warning.call_args_list)
|
||||
|
||||
def test_skips_skill_with_empty_package_root(self):
|
||||
logger = Mock()
|
||||
|
||||
@@ -54,9 +54,7 @@ def test_classify_python_workspace_detects_package_and_requirements():
|
||||
def test_wrap_python_command_with_env_contains_bootstrap_and_command():
|
||||
command = wrap_python_command_with_env('python script.py')
|
||||
|
||||
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
|
||||
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
|
||||
assert 'kill -0 "$_LB_LOCK_OWNER"' in command
|
||||
assert 'python -m venv "$_LB_VENV_DIR"' in command
|
||||
assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command
|
||||
assert command.rstrip().endswith('python script.py')
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Unit tests for command module
|
||||
# Unit tests for command module
|
||||
|
||||
@@ -529,4 +529,4 @@ class TestEmptyAndEdgeInputs:
|
||||
|
||||
# Should yield CommandNotFoundError (no such command registered)
|
||||
assert len(results) == 1
|
||||
assert results[0].error is not None
|
||||
assert results[0].error is not None
|
||||
|
||||
@@ -197,6 +197,7 @@ class TestCommandOperatorBase:
|
||||
op = TestOperator(None)
|
||||
# Should not raise
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(op.initialize())
|
||||
|
||||
def test_execute_is_abstract(self):
|
||||
@@ -299,4 +300,4 @@ class TestMultipleOperators:
|
||||
yield None
|
||||
|
||||
assert AdminOperator.lowest_privilege == 2
|
||||
assert SubOperator.lowest_privilege == 1
|
||||
assert SubOperator.lowest_privilege == 1
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestYAMLConfigFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_yaml_loads(self, tmp_path):
|
||||
"""Valid YAML config should load correctly."""
|
||||
config_file = tmp_path / "test_config.yaml"
|
||||
config_file = tmp_path / 'test_config.yaml'
|
||||
|
||||
# Write valid YAML
|
||||
config_file.write_text("""
|
||||
@@ -51,7 +51,7 @@ settings:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_yaml_raises_error(self, tmp_path):
|
||||
"""Invalid YAML should raise clear error."""
|
||||
config_file = tmp_path / "invalid.yaml"
|
||||
config_file = tmp_path / 'invalid.yaml'
|
||||
|
||||
# Write invalid YAML (unclosed bracket)
|
||||
config_file.write_text("""
|
||||
@@ -67,13 +67,13 @@ settings:
|
||||
template_data={'name': 'default'},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Syntax error"):
|
||||
with pytest.raises(Exception, match='Syntax error'):
|
||||
await yaml_file.load(completion=False)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_config_creates_from_template(self, tmp_path):
|
||||
"""Missing config file should be created from template."""
|
||||
config_file = tmp_path / "new_config.yaml"
|
||||
config_file = tmp_path / 'new_config.yaml'
|
||||
|
||||
# File doesn't exist yet
|
||||
assert not config_file.exists()
|
||||
@@ -92,7 +92,7 @@ settings:
|
||||
@pytest.mark.asyncio
|
||||
async def test_template_completion(self, tmp_path):
|
||||
"""Config should be completed with template defaults."""
|
||||
config_file = tmp_path / "partial.yaml"
|
||||
config_file = tmp_path / 'partial.yaml'
|
||||
|
||||
# Write partial config missing some template keys
|
||||
config_file.write_text("""
|
||||
@@ -115,7 +115,7 @@ name: custom_name
|
||||
@pytest.mark.asyncio
|
||||
async def test_yaml_save(self, tmp_path):
|
||||
"""YAML config can be saved."""
|
||||
config_file = tmp_path / "save_test.yaml"
|
||||
config_file = tmp_path / 'save_test.yaml'
|
||||
|
||||
yaml_file = YAMLConfigFile(
|
||||
str(config_file),
|
||||
@@ -131,7 +131,7 @@ name: custom_name
|
||||
|
||||
def test_yaml_save_sync(self, tmp_path):
|
||||
"""YAML config can be saved synchronously."""
|
||||
config_file = tmp_path / "sync_save.yaml"
|
||||
config_file = tmp_path / 'sync_save.yaml'
|
||||
|
||||
yaml_file = YAMLConfigFile(
|
||||
str(config_file),
|
||||
@@ -151,14 +151,18 @@ class TestJSONConfigFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_json_loads(self, tmp_path):
|
||||
"""Valid JSON config should load correctly."""
|
||||
config_file = tmp_path / "test_config.json"
|
||||
config_file = tmp_path / 'test_config.json'
|
||||
|
||||
# Write valid JSON
|
||||
config_file.write_text(json.dumps({
|
||||
'name': 'json_app',
|
||||
'version': '1.0',
|
||||
'settings': {'debug': True, 'port': 8080},
|
||||
}))
|
||||
config_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
'name': 'json_app',
|
||||
'version': '1.0',
|
||||
'settings': {'debug': True, 'port': 8080},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
json_file = JSONConfigFile(
|
||||
str(config_file),
|
||||
@@ -174,7 +178,7 @@ class TestJSONConfigFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_raises_error(self, tmp_path):
|
||||
"""Invalid JSON should raise clear error."""
|
||||
config_file = tmp_path / "invalid.json"
|
||||
config_file = tmp_path / 'invalid.json'
|
||||
|
||||
# Write invalid JSON (missing closing brace)
|
||||
config_file.write_text('{"name": "test", "unclosed": ')
|
||||
@@ -184,13 +188,13 @@ class TestJSONConfigFile:
|
||||
template_data={'name': 'default'},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Syntax error"):
|
||||
with pytest.raises(Exception, match='Syntax error'):
|
||||
await json_file.load(completion=False)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_json_creates_from_template(self, tmp_path):
|
||||
"""Missing JSON file should be created from template."""
|
||||
config_file = tmp_path / "new_config.json"
|
||||
config_file = tmp_path / 'new_config.json'
|
||||
|
||||
json_file = JSONConfigFile(
|
||||
str(config_file),
|
||||
@@ -205,7 +209,7 @@ class TestJSONConfigFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_save(self, tmp_path):
|
||||
"""JSON config can be saved."""
|
||||
config_file = tmp_path / "save_test.json"
|
||||
config_file = tmp_path / 'save_test.json'
|
||||
|
||||
json_file = JSONConfigFile(
|
||||
str(config_file),
|
||||
@@ -226,7 +230,7 @@ class TestConfigManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_manager_load(self, tmp_path):
|
||||
"""ConfigManager loads config correctly."""
|
||||
config_file = tmp_path / "manager_test.yaml"
|
||||
config_file = tmp_path / 'manager_test.yaml'
|
||||
config_file.write_text('name: managed_app\nversion: "1.0"\n')
|
||||
|
||||
yaml_file = YAMLConfigFile(
|
||||
@@ -243,7 +247,7 @@ class TestConfigManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_manager_dump(self, tmp_path):
|
||||
"""ConfigManager can dump config."""
|
||||
config_file = tmp_path / "dump_test.yaml"
|
||||
config_file = tmp_path / 'dump_test.yaml'
|
||||
|
||||
yaml_file = YAMLConfigFile(
|
||||
str(config_file),
|
||||
@@ -260,7 +264,7 @@ class TestConfigManager:
|
||||
|
||||
def test_config_manager_dump_sync(self, tmp_path):
|
||||
"""ConfigManager can dump config synchronously."""
|
||||
config_file = tmp_path / "sync_dump.yaml"
|
||||
config_file = tmp_path / 'sync_dump.yaml'
|
||||
|
||||
yaml_file = YAMLConfigFile(
|
||||
str(config_file),
|
||||
@@ -280,7 +284,7 @@ class TestConfigExists:
|
||||
|
||||
def test_yaml_exists_true(self, tmp_path):
|
||||
"""exists() returns True for existing file."""
|
||||
config_file = tmp_path / "exists.yaml"
|
||||
config_file = tmp_path / 'exists.yaml'
|
||||
config_file.write_text('name: test')
|
||||
|
||||
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
||||
@@ -288,14 +292,14 @@ class TestConfigExists:
|
||||
|
||||
def test_yaml_exists_false(self, tmp_path):
|
||||
"""exists() returns False for missing file."""
|
||||
config_file = tmp_path / "missing.yaml"
|
||||
config_file = tmp_path / 'missing.yaml'
|
||||
|
||||
yaml_file = YAMLConfigFile(str(config_file), template_data={})
|
||||
assert yaml_file.exists() is False
|
||||
|
||||
def test_json_exists_true(self, tmp_path):
|
||||
"""exists() returns True for existing JSON file."""
|
||||
config_file = tmp_path / "exists.json"
|
||||
config_file = tmp_path / 'exists.json'
|
||||
config_file.write_text('{}')
|
||||
|
||||
json_file = JSONConfigFile(str(config_file), template_data={})
|
||||
@@ -303,7 +307,7 @@ class TestConfigExists:
|
||||
|
||||
def test_json_exists_false(self, tmp_path):
|
||||
"""exists() returns False for missing JSON file."""
|
||||
config_file = tmp_path / "missing.json"
|
||||
config_file = tmp_path / 'missing.json'
|
||||
|
||||
json_file = JSONConfigFile(str(config_file), template_data={})
|
||||
assert json_file.exists() is False
|
||||
assert json_file.exists() is False
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""Core module unit tests."""
|
||||
"""Core module unit tests."""
|
||||
|
||||
@@ -4,6 +4,7 @@ Tests cover:
|
||||
- _get_positive_int_config() validation
|
||||
- _get_positive_float_config() validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
@@ -188,4 +189,4 @@ class TestGetPositiveFloatConfig:
|
||||
result = app._get_positive_float_config('not-a-number', default=1.5, name='test.config')
|
||||
|
||||
assert result == 1.5
|
||||
mock_logger.warning.assert_called_once()
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
@@ -27,6 +27,7 @@ class TestCheckDeps:
|
||||
from langbot.pkg.core.bootutils.deps import check_deps
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||
|
||||
assert result == []
|
||||
@@ -46,6 +47,7 @@ class TestCheckDeps:
|
||||
from langbot.pkg.core.bootutils.deps import check_deps
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||
|
||||
assert 'requests' in result
|
||||
@@ -61,6 +63,7 @@ class TestCheckDeps:
|
||||
from langbot.pkg.core.bootutils.deps import check_deps, required_deps
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(check_deps())
|
||||
|
||||
# Should include all required_deps keys
|
||||
@@ -107,6 +110,7 @@ class TestPrecheckPluginDeps:
|
||||
with patch('os.path.exists', return_value=False):
|
||||
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
||||
|
||||
mock_install.assert_not_called()
|
||||
@@ -129,6 +133,7 @@ class TestPrecheckPluginDeps:
|
||||
with patch('os.listdir', side_effect=mock_listdir):
|
||||
with patch('langbot.pkg.core.bootutils.deps.pkgmgr.install_requirements') as mock_install:
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(precheck_plugin_deps())
|
||||
|
||||
mock_install.assert_called_once_with('plugins/plugin1/requirements.txt', extra_params=[])
|
||||
|
||||
@@ -7,6 +7,7 @@ Tests cover:
|
||||
- Dict type skipping
|
||||
- Missing key creation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -248,15 +249,8 @@ class TestApplyEnvOverridesToConfig:
|
||||
"""Test multiple env vars applied in order."""
|
||||
load_config = get_load_config_module()
|
||||
|
||||
cfg = {
|
||||
'system': {'name': 'default', 'enable': True},
|
||||
'concurrency': {'pipeline': 5}
|
||||
}
|
||||
env = {
|
||||
'SYSTEM__NAME': 'custom',
|
||||
'SYSTEM__ENABLE': 'false',
|
||||
'CONCURRENCY__PIPELINE': '10'
|
||||
}
|
||||
cfg = {'system': {'name': 'default', 'enable': True}, 'concurrency': {'pipeline': 5}}
|
||||
env = {'SYSTEM__NAME': 'custom', 'SYSTEM__ENABLE': 'false', 'CONCURRENCY__PIPELINE': '10'}
|
||||
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
@@ -287,4 +281,4 @@ class TestApplyEnvOverridesToConfig:
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
result = load_config._apply_env_overrides_to_config(cfg)
|
||||
|
||||
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||
assert result['api']['extra_webhook_prefix'] == 'https://extra.example.com'
|
||||
|
||||
@@ -175,4 +175,4 @@ class TestPreregisteredStages:
|
||||
pass
|
||||
|
||||
for key in preregistered_stages:
|
||||
assert isinstance(key, str)
|
||||
assert isinstance(key, str)
|
||||
|
||||
@@ -7,6 +7,7 @@ Tests cover:
|
||||
|
||||
Note: Uses import_isolation to break circular import chains.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -19,15 +20,17 @@ from typing import Generator
|
||||
|
||||
class MockLifecycleControlScopeEnum:
|
||||
"""Mock enum value for LifecycleControlScope with .value attribute."""
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
return f"LifecycleControlScope.{self.value.upper()}"
|
||||
return f'LifecycleControlScope.{self.value.upper()}'
|
||||
|
||||
|
||||
class MockLifecycleControlScope:
|
||||
"""Mock enum for LifecycleControlScope."""
|
||||
|
||||
APPLICATION = MockLifecycleControlScopeEnum('application')
|
||||
PLATFORM = MockLifecycleControlScopeEnum('platform')
|
||||
PIPELINE = MockLifecycleControlScopeEnum('pipeline')
|
||||
@@ -40,17 +43,17 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
||||
# Mock modules that cause circular imports
|
||||
mock_entities = MagicMock()
|
||||
mock_entities.LifecycleControlScope = MockLifecycleControlScope
|
||||
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
|
||||
mock_importutil = MagicMock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
|
||||
|
||||
|
||||
mock_http_controller = MagicMock()
|
||||
|
||||
|
||||
mock_rag_mgr = MagicMock()
|
||||
|
||||
|
||||
mocks = {
|
||||
'langbot.pkg.core.entities': mock_entities,
|
||||
'langbot.pkg.core.app': mock_app,
|
||||
@@ -58,26 +61,26 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
||||
'langbot.pkg.rag.knowledge.kbmgr': mock_rag_mgr,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
}
|
||||
|
||||
|
||||
# Save original state
|
||||
saved = {}
|
||||
for name in mocks:
|
||||
if name in sys.modules:
|
||||
saved[name] = sys.modules[name]
|
||||
|
||||
|
||||
# Clear taskmgr to force re-import
|
||||
taskmgr_name = 'langbot.pkg.core.taskmgr'
|
||||
if taskmgr_name in sys.modules:
|
||||
saved[taskmgr_name] = sys.modules[taskmgr_name]
|
||||
|
||||
|
||||
try:
|
||||
# Apply mocks
|
||||
for name, module in mocks.items():
|
||||
sys.modules[name] = module
|
||||
|
||||
|
||||
# Clear taskmgr
|
||||
sys.modules.pop(taskmgr_name, None)
|
||||
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore
|
||||
@@ -86,7 +89,7 @@ def isolated_taskmgr_import() -> Generator[None, None, None]:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
if taskmgr_name in saved:
|
||||
sys.modules[taskmgr_name] = saved[taskmgr_name]
|
||||
else:
|
||||
@@ -97,6 +100,7 @@ def get_taskmgr_classes():
|
||||
"""Get TaskContext, TaskWrapper, AsyncTaskManager classes."""
|
||||
with isolated_taskmgr_import():
|
||||
from langbot.pkg.core.taskmgr import TaskContext, TaskWrapper, AsyncTaskManager
|
||||
|
||||
return TaskContext, TaskWrapper, AsyncTaskManager
|
||||
|
||||
|
||||
@@ -194,9 +198,10 @@ class TestTaskContext:
|
||||
"""Test TaskContext.placeholder() returns singleton."""
|
||||
with isolated_taskmgr_import():
|
||||
from langbot.pkg.core.taskmgr import TaskContext
|
||||
|
||||
|
||||
# Reset global placeholder
|
||||
import langbot.pkg.core.taskmgr as taskmgr_module
|
||||
|
||||
taskmgr_module.placeholder_context = None
|
||||
|
||||
ctx1 = TaskContext.placeholder()
|
||||
@@ -269,7 +274,8 @@ class TestTaskWrapper:
|
||||
return 'result'
|
||||
|
||||
wrapper = TaskWrapper(
|
||||
mock_app, immediate_coro(),
|
||||
mock_app,
|
||||
immediate_coro(),
|
||||
name='test_task',
|
||||
label='Test Task',
|
||||
)
|
||||
@@ -414,7 +420,7 @@ class TestAsyncTaskManager:
|
||||
async def test_cancel_by_scope(self):
|
||||
"""Test cancel_by_scope cancels matching tasks."""
|
||||
_, _, AsyncTaskManager = get_taskmgr_classes()
|
||||
|
||||
|
||||
mock_app = create_mock_app()
|
||||
manager = AsyncTaskManager(mock_app)
|
||||
|
||||
@@ -422,16 +428,10 @@ class TestAsyncTaskManager:
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Create task with APPLICATION scope
|
||||
w1 = manager.create_task(
|
||||
long_coro(),
|
||||
scopes=[MockLifecycleControlScope.APPLICATION]
|
||||
)
|
||||
w1 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.APPLICATION])
|
||||
|
||||
# Create task with different scope
|
||||
w2 = manager.create_task(
|
||||
long_coro(),
|
||||
scopes=[MockLifecycleControlScope.PIPELINE]
|
||||
)
|
||||
w2 = manager.create_task(long_coro(), scopes=[MockLifecycleControlScope.PIPELINE])
|
||||
|
||||
manager.cancel_by_scope(MockLifecycleControlScope.APPLICATION)
|
||||
|
||||
|
||||
@@ -15,68 +15,68 @@ class TestI18nString:
|
||||
|
||||
def test_create_with_english_only(self):
|
||||
"""Create I18nString with only English."""
|
||||
i18n = I18nString(en_US="Hello")
|
||||
i18n = I18nString(en_US='Hello')
|
||||
|
||||
assert i18n.en_US == "Hello"
|
||||
assert i18n.en_US == 'Hello'
|
||||
assert i18n.zh_Hans is None
|
||||
|
||||
def test_create_with_multiple_languages(self):
|
||||
"""Create I18nString with multiple languages."""
|
||||
i18n = I18nString(
|
||||
en_US="Hello",
|
||||
zh_Hans="你好",
|
||||
zh_Hant="你好",
|
||||
ja_JP="こんにちは",
|
||||
en_US='Hello',
|
||||
zh_Hans='你好',
|
||||
zh_Hant='你好',
|
||||
ja_JP='こんにちは',
|
||||
)
|
||||
|
||||
assert i18n.en_US == "Hello"
|
||||
assert i18n.zh_Hans == "你好"
|
||||
assert i18n.zh_Hant == "你好"
|
||||
assert i18n.ja_JP == "こんにちは"
|
||||
assert i18n.en_US == 'Hello'
|
||||
assert i18n.zh_Hans == '你好'
|
||||
assert i18n.zh_Hant == '你好'
|
||||
assert i18n.ja_JP == 'こんにちは'
|
||||
|
||||
def test_to_dict_with_english_only(self):
|
||||
"""to_dict returns only non-None fields."""
|
||||
i18n = I18nString(en_US="Hello")
|
||||
i18n = I18nString(en_US='Hello')
|
||||
|
||||
result = i18n.to_dict()
|
||||
|
||||
assert result == {"en_US": "Hello"}
|
||||
assert result == {'en_US': 'Hello'}
|
||||
|
||||
def test_to_dict_with_multiple_languages(self):
|
||||
"""to_dict returns all non-None fields."""
|
||||
i18n = I18nString(
|
||||
en_US="Hello",
|
||||
zh_Hans="你好",
|
||||
en_US='Hello',
|
||||
zh_Hans='你好',
|
||||
)
|
||||
|
||||
result = i18n.to_dict()
|
||||
|
||||
assert result == {"en_US": "Hello", "zh_Hans": "你好"}
|
||||
assert result == {'en_US': 'Hello', 'zh_Hans': '你好'}
|
||||
|
||||
def test_to_dict_excludes_none(self):
|
||||
"""to_dict excludes None values."""
|
||||
i18n = I18nString(
|
||||
en_US="Hello",
|
||||
en_US='Hello',
|
||||
zh_Hans=None,
|
||||
ja_JP="こんにちは",
|
||||
ja_JP='こんにちは',
|
||||
)
|
||||
|
||||
result = i18n.to_dict()
|
||||
|
||||
assert "zh_Hans" not in result
|
||||
assert "en_US" in result
|
||||
assert "ja_JP" in result
|
||||
assert 'zh_Hans' not in result
|
||||
assert 'en_US' in result
|
||||
assert 'ja_JP' in result
|
||||
|
||||
def test_to_dict_all_languages(self):
|
||||
"""to_dict with all supported languages."""
|
||||
i18n = I18nString(
|
||||
en_US="Hello",
|
||||
zh_Hans="你好",
|
||||
zh_Hant="你好",
|
||||
ja_JP="こんにちは",
|
||||
th_TH="สวัสดี",
|
||||
vi_VN="Xin chào",
|
||||
es_ES="Hola",
|
||||
en_US='Hello',
|
||||
zh_Hans='你好',
|
||||
zh_Hant='你好',
|
||||
ja_JP='こんにちは',
|
||||
th_TH='สวัสดี',
|
||||
vi_VN='Xin chào',
|
||||
es_ES='Hola',
|
||||
)
|
||||
|
||||
result = i18n.to_dict()
|
||||
@@ -92,30 +92,30 @@ class TestMetadata:
|
||||
from langbot.pkg.discover.engine import I18nString
|
||||
|
||||
metadata = Metadata(
|
||||
name="test-component",
|
||||
label=I18nString(en_US="Test Component"),
|
||||
name='test-component',
|
||||
label=I18nString(en_US='Test Component'),
|
||||
)
|
||||
|
||||
assert metadata.name == "test-component"
|
||||
assert metadata.label.en_US == "Test Component"
|
||||
assert metadata.name == 'test-component'
|
||||
assert metadata.label.en_US == 'Test Component'
|
||||
|
||||
def test_create_with_all_fields(self):
|
||||
"""Create Metadata with all optional fields."""
|
||||
from langbot.pkg.discover.engine import I18nString
|
||||
|
||||
metadata = Metadata(
|
||||
name="test-component",
|
||||
label=I18nString(en_US="Test"),
|
||||
description=I18nString(en_US="A test component"),
|
||||
version="1.0.0",
|
||||
icon="test-icon",
|
||||
author="Test Author",
|
||||
repository="https://github.com/test/repo",
|
||||
name='test-component',
|
||||
label=I18nString(en_US='Test'),
|
||||
description=I18nString(en_US='A test component'),
|
||||
version='1.0.0',
|
||||
icon='test-icon',
|
||||
author='Test Author',
|
||||
repository='https://github.com/test/repo',
|
||||
)
|
||||
|
||||
assert metadata.version == "1.0.0"
|
||||
assert metadata.icon == "test-icon"
|
||||
assert metadata.author == "Test Author"
|
||||
assert metadata.version == '1.0.0'
|
||||
assert metadata.icon == 'test-icon'
|
||||
assert metadata.author == 'Test Author'
|
||||
|
||||
|
||||
class TestComponentManifest:
|
||||
|
||||
@@ -7,6 +7,7 @@ Tests cover:
|
||||
|
||||
Note: Uses import isolation to break circular import chains.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
@@ -86,6 +87,7 @@ def get_database_module():
|
||||
"""Get database module with import isolation."""
|
||||
with isolated_database_import():
|
||||
from langbot.pkg.persistence import database
|
||||
|
||||
return database
|
||||
|
||||
|
||||
@@ -198,4 +200,4 @@ class TestManagerClassDecorator:
|
||||
# Create instance to test method (with mock app)
|
||||
mock_app = Mock()
|
||||
instance = ManagerWithMethods(mock_app)
|
||||
assert instance.custom_method() == 'test_value'
|
||||
assert instance.custom_method() == 'test_value'
|
||||
|
||||
@@ -4,6 +4,7 @@ Tests cover:
|
||||
- execute_async() with mock database
|
||||
- get_db_engine() with mock database manager
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -85,7 +86,7 @@ class TestExecuteAsync:
|
||||
mock_db.get_engine = Mock(return_value=mock_engine)
|
||||
mgr.db = mock_db
|
||||
|
||||
result = await mgr.execute_async(sqlalchemy.text("SELECT 1"))
|
||||
result = await mgr.execute_async(sqlalchemy.text('SELECT 1'))
|
||||
|
||||
# Verify result is the same object returned by execute
|
||||
assert result is mock_result
|
||||
@@ -152,4 +153,4 @@ class TestSerializeModelEdgeCases:
|
||||
result = mgr.serialize_model(SimpleModel, instance, masked_columns=['id', 'name'])
|
||||
|
||||
# Result should be empty dict when all columns masked
|
||||
assert result == {}
|
||||
assert result == {}
|
||||
|
||||
@@ -5,6 +5,7 @@ Tests cover:
|
||||
- datetime conversion to isoformat
|
||||
- masked_columns exclusion
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
@@ -49,7 +49,7 @@ class TestPendingMessage:
|
||||
"""PendingMessage should be created with correct fields."""
|
||||
aggregator = get_aggregator_module()
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -88,7 +88,7 @@ class TestSessionBuffer:
|
||||
"""SessionBuffer should accept initial messages."""
|
||||
aggregator = get_aggregator_module()
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
|
||||
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
|
||||
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain1 = text_chain("hello")
|
||||
chain2 = text_chain("world")
|
||||
chain1 = text_chain('hello')
|
||||
chain2 = text_chain('world')
|
||||
event = friend_message_event(chain1)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
|
||||
|
||||
# Should contain both messages with separator
|
||||
merged_str = str(merged.message_chain)
|
||||
assert "hello" in merged_str
|
||||
assert "world" in merged_str
|
||||
assert 'hello' in merged_str
|
||||
assert 'world' in merged_str
|
||||
|
||||
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
|
||||
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
|
||||
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain1 = text_chain("first")
|
||||
chain2 = text_chain("second")
|
||||
chain1 = text_chain('first')
|
||||
chain2 = text_chain('second')
|
||||
event = friend_message_event(chain1)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain = text_chain("hello")
|
||||
chain = text_chain('hello')
|
||||
event = friend_message_event(chain)
|
||||
adapter = mock_adapter()
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from tests.factories import FakeApp
|
||||
|
||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
@@ -36,9 +37,11 @@ def mock_circular_import_chain():
|
||||
# Create a default runner that yields a simple response
|
||||
class DefaultRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='fake response')
|
||||
|
||||
@@ -70,9 +73,12 @@ def mock_event_ctx():
|
||||
@pytest.fixture
|
||||
def set_runner():
|
||||
"""Factory fixture to set a custom runner for tests."""
|
||||
|
||||
def _set_runner(runner_class):
|
||||
import sys
|
||||
|
||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
|
||||
|
||||
return _set_runner
|
||||
|
||||
|
||||
@@ -87,6 +93,7 @@ def get_chat_handler():
|
||||
global _chat_handler_module
|
||||
if _chat_handler_module is None:
|
||||
from importlib import import_module
|
||||
|
||||
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
|
||||
return _chat_handler_module
|
||||
|
||||
@@ -96,12 +103,14 @@ def get_entities():
|
||||
global _entities_module
|
||||
if _entities_module is None:
|
||||
from importlib import import_module
|
||||
|
||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||
return _entities_module
|
||||
|
||||
|
||||
# ============== REAL ChatMessageHandler Tests ==============
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestChatMessageHandlerReal:
|
||||
"""Tests for real ChatMessageHandler class."""
|
||||
@@ -188,9 +197,11 @@ class TestChatMessageHandlerReal:
|
||||
|
||||
class QuickRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='ok')
|
||||
|
||||
@@ -222,9 +233,11 @@ class TestChatMessageHandlerReal:
|
||||
|
||||
class SingleRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='response')
|
||||
|
||||
@@ -262,9 +275,11 @@ class TestChatHandlerStreaming:
|
||||
|
||||
class StreamRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
yield MessageChunk(role='assistant', content='Hello', is_final=False)
|
||||
yield MessageChunk(role='assistant', content=' World', is_final=True)
|
||||
@@ -303,14 +318,19 @@ class TestChatHandlerExceptions:
|
||||
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
||||
},
|
||||
}
|
||||
|
||||
class FailingRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
raise ValueError('API error')
|
||||
yield
|
||||
@@ -346,14 +366,19 @@ class TestChatHandlerExceptions:
|
||||
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'show-error'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
||||
},
|
||||
}
|
||||
|
||||
class ErrorRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
raise ValueError('Custom error')
|
||||
yield
|
||||
@@ -386,14 +411,19 @@ class TestChatHandlerExceptions:
|
||||
|
||||
query.pipeline_config = {
|
||||
'output': {'misc': {'exception-handling': 'hide'}},
|
||||
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
|
||||
},
|
||||
}
|
||||
|
||||
class HideErrorRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
raise RuntimeError('hidden')
|
||||
yield
|
||||
@@ -433,4 +463,4 @@ class TestChatHandlerHelper:
|
||||
chat = get_chat_handler()
|
||||
handler = chat.ChatMessageHandler(fake_app)
|
||||
result = handler.cut_str('first line\nsecond line')
|
||||
assert '...' in result
|
||||
assert '...' in result
|
||||
|
||||
@@ -67,7 +67,11 @@ def make_pipeline_config(**overrides):
|
||||
for key, value in overrides.items():
|
||||
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
|
||||
for sub_key, sub_value in value.items():
|
||||
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
|
||||
if (
|
||||
sub_key in base_config[key]
|
||||
and isinstance(base_config[key][sub_key], dict)
|
||||
and isinstance(sub_value, dict)
|
||||
):
|
||||
base_config[key][sub_key].update(sub_value)
|
||||
else:
|
||||
base_config[key][sub_key] = sub_value
|
||||
@@ -141,7 +145,7 @@ class TestPreContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello world")
|
||||
query = text_query('hello world')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -163,7 +167,7 @@ class TestPreContentFilter:
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
# Empty message chain
|
||||
query = text_query("")
|
||||
query = text_query('')
|
||||
query.message_chain = platform_message.MessageChain([])
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
@@ -185,7 +189,7 @@ class TestPreContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query(" ") # Only whitespace
|
||||
query = text_query(' ') # Only whitespace
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -234,7 +238,7 @@ class TestPreContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello world")
|
||||
query = text_query('hello world')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -266,7 +270,7 @@ class TestContentIgnoreFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("/help me")
|
||||
query = text_query('/help me')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -294,7 +298,7 @@ class TestContentIgnoreFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("http://example.com")
|
||||
query = text_query('http://example.com')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -322,7 +326,7 @@ class TestContentIgnoreFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("normal message")
|
||||
query = text_query('normal message')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -343,7 +347,7 @@ class TestContentIgnoreFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("/help me")
|
||||
query = text_query('/help me')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
@@ -368,12 +372,10 @@ class TestPostContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
# Add a response message
|
||||
query.resp_messages = [
|
||||
provider_message.Message(role='assistant', content='Hello back!')
|
||||
]
|
||||
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
|
||||
|
||||
result = await stage.process(query, 'PostContentFilterStage')
|
||||
|
||||
@@ -398,11 +400,9 @@ class TestPostContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_messages = [
|
||||
provider_message.Message(role='assistant', content='Response')
|
||||
]
|
||||
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
|
||||
|
||||
result = await stage.process(query, 'PostContentFilterStage')
|
||||
|
||||
@@ -422,7 +422,7 @@ class TestPostContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
# Non-string content - use model_construct to bypass validation
|
||||
# The actual content type could be a list of ContentElement objects
|
||||
@@ -450,11 +450,9 @@ class TestPostContentFilter:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_messages = [
|
||||
provider_message.Message(role='assistant', content='')
|
||||
]
|
||||
query.resp_messages = [provider_message.Message(role='assistant', content='')]
|
||||
|
||||
result = await stage.process(query, 'PostContentFilterStage')
|
||||
|
||||
@@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
|
||||
@@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("normal message without prefix")
|
||||
query = text_query('normal message without prefix')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
result = await stage.process(query, 'PreContentFilterStage')
|
||||
|
||||
@@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query
|
||||
|
||||
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mock_circular_import_chain():
|
||||
"""
|
||||
@@ -56,6 +57,7 @@ def mock_event_ctx():
|
||||
@pytest.fixture
|
||||
def mock_execute_factory():
|
||||
"""Factory fixture to create mock cmd_mgr.execute generators."""
|
||||
|
||||
def _create_execute(
|
||||
text: str | None = 'ok',
|
||||
error: str | None = None,
|
||||
@@ -71,7 +73,9 @@ def mock_execute_factory():
|
||||
ret.image_base64 = image_base64
|
||||
ret.file_url = file_url
|
||||
yield ret
|
||||
|
||||
return mock_execute
|
||||
|
||||
return _create_execute
|
||||
|
||||
|
||||
@@ -86,6 +90,7 @@ def get_command_handler():
|
||||
global _command_handler_module
|
||||
if _command_handler_module is None:
|
||||
from importlib import import_module
|
||||
|
||||
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
|
||||
return _command_handler_module
|
||||
|
||||
@@ -95,12 +100,14 @@ def get_entities():
|
||||
global _entities_module
|
||||
if _entities_module is None:
|
||||
from importlib import import_module
|
||||
|
||||
_entities_module = import_module('langbot.pkg.pipeline.entities')
|
||||
return _entities_module
|
||||
|
||||
|
||||
# ============== REAL CommandHandler Tests ==============
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('mock_circular_import_chain')
|
||||
class TestCommandHandlerReal:
|
||||
"""Tests for real CommandHandler class."""
|
||||
@@ -127,6 +134,7 @@ class TestCommandHandlerReal:
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
executed_commands = []
|
||||
|
||||
async def track_execute(command_text, full_command_text, query, session):
|
||||
executed_commands.append(command_text)
|
||||
ret = Mock()
|
||||
@@ -334,8 +342,7 @@ class TestCommandHandlerReal:
|
||||
command = get_command_handler()
|
||||
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
fake_app.cmd_mgr.execute = mock_execute_factory(
|
||||
text='Here is the image:',
|
||||
image_url='https://example.com/image.png'
|
||||
text='Here is the image:', image_url='https://example.com/image.png'
|
||||
)
|
||||
|
||||
handler = command.CommandHandler(fake_app)
|
||||
@@ -393,4 +400,4 @@ class TestCommandHandlerHelper:
|
||||
command = get_command_handler()
|
||||
handler = command.CommandHandler(fake_app)
|
||||
result = handler.cut_str('first line\nsecond line')
|
||||
assert '...' in result
|
||||
assert '...' in result
|
||||
|
||||
@@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = [
|
||||
platform_message.MessageChain([platform_message.Plain(text="very long response")])
|
||||
]
|
||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
@@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = [
|
||||
platform_message.MessageChain([platform_message.Plain(text="short response")])
|
||||
]
|
||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
@@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
# Non-Plain component (Image)
|
||||
query.resp_message_chain = [
|
||||
platform_message.MessageChain([
|
||||
platform_message.Plain(text="short"),
|
||||
platform_message.Image(url="https://example.com/img.png")
|
||||
])
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
|
||||
)
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
@@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess:
|
||||
stage = longtext.LongTextProcessStage(app)
|
||||
stage.strategy_impl = AsyncMock()
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess:
|
||||
assert result.new_query is query
|
||||
stage.strategy_impl.process.assert_not_called()
|
||||
|
||||
|
||||
class TestForwardStrategy:
|
||||
"""Tests for ForwardComponentStrategy."""
|
||||
|
||||
@@ -260,7 +256,7 @@ class TestForwardStrategy:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
# Create a mock adapter with bot_account_id
|
||||
mock_adapter = Mock()
|
||||
@@ -268,10 +264,8 @@ class TestForwardStrategy:
|
||||
query.adapter = mock_adapter
|
||||
|
||||
# Long text exceeding threshold
|
||||
long_text = "This is a very long response that exceeds the threshold"
|
||||
query.resp_message_chain = [
|
||||
platform_message.MessageChain([platform_message.Plain(text=long_text)])
|
||||
]
|
||||
long_text = 'This is a very long response that exceeds the threshold'
|
||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
@@ -297,13 +291,13 @@ class TestForwardStrategy:
|
||||
|
||||
await strat.initialize()
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = make_longtext_config()
|
||||
mock_adapter = Mock()
|
||||
mock_adapter.bot_account_id = '12345'
|
||||
query.adapter = mock_adapter
|
||||
|
||||
components = await strat.process("test message", query)
|
||||
components = await strat.process('test message', query)
|
||||
|
||||
assert len(components) == 1
|
||||
assert isinstance(components[0], platform_message.Forward)
|
||||
@@ -326,14 +320,12 @@ class TestLongTextThreshold:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
|
||||
# Text below threshold
|
||||
short_text = "x" * (threshold - 1)
|
||||
query.resp_message_chain = [
|
||||
platform_message.MessageChain([platform_message.Plain(text=short_text)])
|
||||
]
|
||||
short_text = 'x' * (threshold - 1)
|
||||
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
# Create query with 3 messages (within limit)
|
||||
query = text_query("current message")
|
||||
query = text_query('current message')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='message 1'),
|
||||
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
|
||||
|
||||
# Create query with many messages exceeding limit
|
||||
# 7 messages = 3 full rounds + 1 current user
|
||||
query = text_query("current message")
|
||||
query = text_query('current message')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='message 1'),
|
||||
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = []
|
||||
|
||||
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='hello'),
|
||||
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("current")
|
||||
query = text_query('current')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='user1'),
|
||||
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("current")
|
||||
query = text_query('current')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='old1'),
|
||||
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
|
||||
trun = trun_cls(app)
|
||||
break
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = make_truncate_config(max_round=3)
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='m1'),
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = text_query("hello world")
|
||||
query = text_query('hello world')
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
|
||||
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = text_query("test message")
|
||||
query = text_query('test message')
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
|
||||
@@ -194,13 +194,16 @@ class TestPreProcessorImageSegment:
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
# Image query with base64
|
||||
query = image_query(text="look at this", url=None)
|
||||
query = image_query(text='look at this', url=None)
|
||||
# Set base64 on the image component
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
chain = platform_message.MessageChain([
|
||||
platform_message.Plain(text="look at this"),
|
||||
platform_message.Image(base64="data:image/png;base64,abc123"),
|
||||
])
|
||||
|
||||
chain = platform_message.MessageChain(
|
||||
[
|
||||
platform_message.Plain(text='look at this'),
|
||||
platform_message.Image(base64='data:image/png;base64,abc123'),
|
||||
]
|
||||
)
|
||||
query.message_chain = chain
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
@@ -238,7 +241,7 @@ class TestPreProcessorImageSegment:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = image_query(text="describe this")
|
||||
query = image_query(text='describe this')
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
|
||||
@@ -276,7 +279,7 @@ class TestPreProcessorModelSelection:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
# Set pipeline config with primary model
|
||||
query.pipeline_config = {
|
||||
@@ -335,7 +338,7 @@ class TestPreProcessorModelSelection:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
|
||||
query.pipeline_config = {
|
||||
'ai': {
|
||||
@@ -384,7 +387,7 @@ class TestPreProcessorVariables:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = text_query("hello", sender_id=67890)
|
||||
query = text_query('hello', sender_id=67890)
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
|
||||
@@ -421,7 +424,7 @@ class TestPreProcessorVariables:
|
||||
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
|
||||
|
||||
stage = preproc.PreProcessor(app)
|
||||
query = group_text_query("hello", group_id=99999)
|
||||
query = group_text_query('hello', group_id=99999)
|
||||
|
||||
result = await stage.process(query, 'PreProcessor')
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
|
||||
'safety': {
|
||||
'rate-limit': {
|
||||
'window-length': 60, # 60 seconds window
|
||||
'limitation': 10, # 10 requests per window
|
||||
'limitation': 10, # 10 requests per window
|
||||
'strategy': 'drop',
|
||||
}
|
||||
}
|
||||
@@ -75,11 +75,9 @@ class TestFixedWindowAlgo:
|
||||
# Make requests within limit
|
||||
for i in range(10):
|
||||
result = await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
|
||||
)
|
||||
assert result is True, f"Request {i+1} should be allowed"
|
||||
assert result is True, f'Request {i + 1} should be allowed'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||
@@ -91,20 +89,12 @@ class TestFixedWindowAlgo:
|
||||
|
||||
# Exhaust the limit
|
||||
for i in range(10):
|
||||
await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
||||
|
||||
# Next request should be denied
|
||||
result = await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
||||
|
||||
assert result is False, "Request exceeding limit should be denied"
|
||||
assert result is False, 'Request exceeding limit should be denied'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||
@@ -116,20 +106,14 @@ class TestFixedWindowAlgo:
|
||||
|
||||
# Exhaust limit for session 1
|
||||
for i in range(10):
|
||||
await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'session1'
|
||||
)
|
||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
|
||||
|
||||
# Session 2 should still have its own limit
|
||||
result = await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'session2'
|
||||
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
|
||||
)
|
||||
|
||||
assert result is True, "Different session should have independent limit"
|
||||
assert result is True, 'Different session should have independent limit'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
|
||||
@@ -150,19 +134,11 @@ class TestFixedWindowAlgo:
|
||||
await algo.initialize()
|
||||
|
||||
# First request allowed
|
||||
result1 = await algo.require_access(
|
||||
sample_query,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
|
||||
assert result1 is True
|
||||
|
||||
# Second request denied
|
||||
result2 = await algo.require_access(
|
||||
sample_query,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
|
||||
assert result2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -174,11 +150,7 @@ class TestFixedWindowAlgo:
|
||||
await algo.initialize()
|
||||
|
||||
# First request creates container
|
||||
await algo.require_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
||||
|
||||
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
|
||||
expected_key = 'LauncherTypes.PERSON_12345'
|
||||
@@ -230,7 +202,7 @@ class TestFixedWindowAlgo:
|
||||
|
||||
# New request should be allowed (new window)
|
||||
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
|
||||
assert result is True, "New window should allow new requests"
|
||||
assert result is True, 'New window should allow new requests'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
|
||||
@@ -256,29 +228,21 @@ class TestFixedWindowAlgo:
|
||||
|
||||
# First request allowed
|
||||
start_time = time.time()
|
||||
result1 = await algo.require_access(
|
||||
sample_query,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'wait_test'
|
||||
)
|
||||
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
||||
assert result1 is True
|
||||
|
||||
# Exhaust limit
|
||||
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
||||
|
||||
# Third request should wait and then succeed
|
||||
result3 = await algo.require_access(
|
||||
sample_query,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'wait_test'
|
||||
)
|
||||
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
assert result3 is True, "After wait, request should succeed"
|
||||
assert result3 is True, 'After wait, request should succeed'
|
||||
# Should have waited approximately until next window
|
||||
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
|
||||
# Note: This is a timing-sensitive test, so we use a generous tolerance
|
||||
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
|
||||
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
|
||||
@@ -289,11 +253,7 @@ class TestFixedWindowAlgo:
|
||||
await algo.initialize()
|
||||
|
||||
# release_access is empty in current implementation
|
||||
await algo.release_access(
|
||||
sample_query_with_rate_limit,
|
||||
provider_session.LauncherTypes.PERSON,
|
||||
'12345'
|
||||
)
|
||||
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
|
||||
|
||||
# Should not raise or change state
|
||||
assert 'person_12345' not in algo.containers
|
||||
|
||||
@@ -55,7 +55,7 @@ def make_session():
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
use_prompt_name="default",
|
||||
use_prompt_name='default',
|
||||
using_conversation=None,
|
||||
conversations=[],
|
||||
)
|
||||
@@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_messages = [
|
||||
platform_message.MessageChain([platform_message.Plain(text="response")])
|
||||
]
|
||||
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
|
||||
query.resp_message_chain = []
|
||||
|
||||
results = []
|
||||
@@ -125,7 +123,7 @@ class TestResponseWrapperCommand:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -133,7 +131,7 @@ class TestResponseWrapperCommand:
|
||||
command_resp = Mock()
|
||||
command_resp.role = 'command'
|
||||
command_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
|
||||
)
|
||||
query.resp_messages = [command_resp]
|
||||
|
||||
@@ -163,7 +161,7 @@ class TestResponseWrapperPlugin:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -171,7 +169,7 @@ class TestResponseWrapperPlugin:
|
||||
plugin_resp = Mock()
|
||||
plugin_resp.role = 'plugin'
|
||||
plugin_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
|
||||
)
|
||||
query.resp_messages = [plugin_resp]
|
||||
|
||||
@@ -211,17 +209,17 @@ class TestResponseWrapperAssistant:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
# Create assistant response with content
|
||||
assistant_resp = Mock()
|
||||
assistant_resp.role = 'assistant'
|
||||
assistant_resp.content = "Hello back!"
|
||||
assistant_resp.content = 'Hello back!'
|
||||
assistant_resp.tool_calls = None
|
||||
assistant_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
|
||||
)
|
||||
query.resp_messages = [assistant_resp]
|
||||
|
||||
@@ -247,7 +245,7 @@ class TestResponseWrapperAssistant:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -292,7 +290,7 @@ class TestResponseWrapperAssistant:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
@@ -303,10 +301,10 @@ class TestResponseWrapperAssistant:
|
||||
|
||||
assistant_resp = Mock()
|
||||
assistant_resp.role = 'assistant'
|
||||
assistant_resp.content = "Processing..."
|
||||
assistant_resp.content = 'Processing...'
|
||||
assistant_resp.tool_calls = [mock_tool_call]
|
||||
assistant_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
|
||||
)
|
||||
query.resp_messages = [assistant_resp]
|
||||
|
||||
@@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
# Create assistant response with content
|
||||
assistant_resp = Mock()
|
||||
assistant_resp.role = 'assistant'
|
||||
assistant_resp.content = "Hello!"
|
||||
assistant_resp.content = 'Hello!'
|
||||
assistant_resp.tool_calls = None
|
||||
assistant_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
|
||||
)
|
||||
query.resp_messages = [assistant_resp]
|
||||
|
||||
@@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply:
|
||||
app.sess_mgr.get_session = AsyncMock(return_value=session)
|
||||
|
||||
# Mock plugin connector with custom reply
|
||||
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
|
||||
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
|
||||
mock_event_ctx = Mock()
|
||||
mock_event_ctx.is_prevented_default = Mock(return_value=False)
|
||||
mock_event_ctx.event = Mock()
|
||||
@@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
# Create assistant response
|
||||
assistant_resp = Mock()
|
||||
assistant_resp.role = 'assistant'
|
||||
assistant_resp.content = "Default reply"
|
||||
assistant_resp.content = 'Default reply'
|
||||
assistant_resp.tool_calls = None
|
||||
assistant_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
|
||||
)
|
||||
query.resp_messages = [assistant_resp]
|
||||
|
||||
@@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply:
|
||||
assert len(results[0].new_query.resp_message_chain) == 1
|
||||
# Should be the custom chain
|
||||
chain = results[0].new_query.resp_message_chain[0]
|
||||
assert "Custom reply" in str(chain)
|
||||
assert 'Custom reply' in str(chain)
|
||||
|
||||
|
||||
class TestResponseWrapperVariables:
|
||||
@@ -452,7 +450,7 @@ class TestResponseWrapperVariables:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query = text_query('hello')
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
|
||||
@@ -460,10 +458,10 @@ class TestResponseWrapperVariables:
|
||||
# Create assistant response
|
||||
assistant_resp = Mock()
|
||||
assistant_resp.role = 'assistant'
|
||||
assistant_resp.content = "Hello"
|
||||
assistant_resp.content = 'Hello'
|
||||
assistant_resp.tool_calls = None
|
||||
assistant_resp.get_content_platform_message_chain = Mock(
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
|
||||
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
|
||||
)
|
||||
query.resp_messages = [assistant_resp]
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Tests cover:
|
||||
- RAG methods (ingest, retrieve, schema)
|
||||
- Disabled plugin early returns
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -86,16 +87,12 @@ class TestListPlugins:
|
||||
return_value=[
|
||||
{
|
||||
'manifest': {'manifest': {'metadata': {'author': 'a', 'name': 'p1'}}},
|
||||
'components': [
|
||||
{'manifest': {'manifest': {'kind': 'Command'}}}
|
||||
],
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Command'}}}],
|
||||
'debug': False,
|
||||
},
|
||||
{
|
||||
'manifest': {'manifest': {'metadata': {'author': 'b', 'name': 'p2'}}},
|
||||
'components': [
|
||||
{'manifest': {'manifest': {'kind': 'Tool'}}}
|
||||
],
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool'}}}],
|
||||
'debug': False,
|
||||
},
|
||||
]
|
||||
@@ -127,9 +124,7 @@ class TestListPlugins:
|
||||
},
|
||||
]
|
||||
)
|
||||
connector.ap.persistence_mgr.execute_async = AsyncMock(
|
||||
return_value=Mock(__iter__=lambda self: iter([]))
|
||||
)
|
||||
connector.ap.persistence_mgr.execute_async = AsyncMock(return_value=Mock(__iter__=lambda self: iter([])))
|
||||
|
||||
result = await connector.list_plugins()
|
||||
|
||||
@@ -230,7 +225,8 @@ class TestCallParser:
|
||||
)
|
||||
|
||||
connector.handler.parse_document.assert_called_once_with(
|
||||
'author', 'parser',
|
||||
'author',
|
||||
'parser',
|
||||
{'mime_type': 'text/plain', 'filename': 'test.txt'},
|
||||
b'file content',
|
||||
)
|
||||
@@ -251,9 +247,7 @@ class TestRAGMethods:
|
||||
|
||||
result = await connector.call_rag_ingest('author/engine', {'file': 'test.pdf'})
|
||||
|
||||
connector.handler.rag_ingest_document.assert_called_once_with(
|
||||
'author', 'engine', {'file': 'test.pdf'}
|
||||
)
|
||||
connector.handler.rag_ingest_document.assert_called_once_with('author', 'engine', {'file': 'test.pdf'})
|
||||
assert result['status'] == 'success'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,14 +258,16 @@ class TestRAGMethods:
|
||||
|
||||
connector.handler = AsyncMock()
|
||||
connector.handler.retrieve_knowledge = AsyncMock(
|
||||
return_value={'results': [{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}]}
|
||||
return_value={
|
||||
'results': [
|
||||
{'id': 'doc1', 'content': [{'type': 'text', 'text': 'test'}], 'metadata': {}, 'distance': 0.1}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
result = await connector.call_rag_retrieve('author/engine', {'query': 'test'})
|
||||
|
||||
connector.handler.retrieve_knowledge.assert_called_once_with(
|
||||
'author', 'engine', '', {'query': 'test'}
|
||||
)
|
||||
connector.handler.retrieve_knowledge.assert_called_once_with('author', 'engine', '', {'query': 'test'})
|
||||
assert result == {
|
||||
'results': [
|
||||
{
|
||||
@@ -290,9 +286,7 @@ class TestRAGMethods:
|
||||
connector = create_mock_connector()
|
||||
|
||||
connector.handler = AsyncMock()
|
||||
connector.handler.get_rag_creation_schema = AsyncMock(
|
||||
return_value={'properties': {'name': {'type': 'string'}}}
|
||||
)
|
||||
connector.handler.get_rag_creation_schema = AsyncMock(return_value={'properties': {'name': {'type': 'string'}}})
|
||||
|
||||
result = await connector.get_rag_creation_schema('author/engine')
|
||||
|
||||
@@ -326,9 +320,7 @@ class TestRAGMethods:
|
||||
|
||||
await connector.rag_on_kb_create('author/engine', 'kb-uuid', {'model': 'test'})
|
||||
|
||||
connector.handler.rag_on_kb_create.assert_called_once_with(
|
||||
'author', 'engine', 'kb-uuid', {'model': 'test'}
|
||||
)
|
||||
connector.handler.rag_on_kb_create.assert_called_once_with('author', 'engine', 'kb-uuid', {'model': 'test'})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_on_kb_delete(self):
|
||||
@@ -354,9 +346,7 @@ class TestRAGMethods:
|
||||
|
||||
result = await connector.call_rag_delete_document('author/engine', 'doc-uuid', 'kb-uuid')
|
||||
|
||||
connector.handler.rag_delete_document.assert_called_once_with(
|
||||
'author', 'engine', 'doc-uuid', 'kb-uuid'
|
||||
)
|
||||
connector.handler.rag_delete_document.assert_called_once_with('author', 'engine', 'doc-uuid', 'kb-uuid')
|
||||
assert result is True
|
||||
|
||||
|
||||
@@ -446,9 +436,7 @@ class TestGetPluginInfo:
|
||||
connector = create_mock_connector()
|
||||
|
||||
connector.handler = AsyncMock()
|
||||
connector.handler.get_plugin_info = AsyncMock(
|
||||
return_value={'manifest': {'metadata': {'name': 'plugin'}}}
|
||||
)
|
||||
connector.handler.get_plugin_info = AsyncMock(return_value={'manifest': {'metadata': {'name': 'plugin'}}})
|
||||
|
||||
result = await connector.get_plugin_info('author', 'plugin')
|
||||
|
||||
@@ -470,9 +458,7 @@ class TestSetPluginConfig:
|
||||
|
||||
await connector.set_plugin_config('author', 'plugin', {'setting': 'value'})
|
||||
|
||||
connector.handler.set_plugin_config.assert_called_once_with(
|
||||
'author', 'plugin', {'setting': 'value'}
|
||||
)
|
||||
connector.handler.set_plugin_config.assert_called_once_with('author', 'plugin', {'setting': 'value'})
|
||||
|
||||
|
||||
class TestPingPluginRuntime:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Tests cover:
|
||||
- _parse_plugin_id() parsing and validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -6,6 +6,7 @@ Tests cover:
|
||||
- Handling missing requirements.txt
|
||||
- Handling empty/malformed requirements.txt
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import zipfile
|
||||
@@ -82,13 +83,13 @@ class TestExtractDepsMetadata:
|
||||
"""Test that comments and empty lines are filtered."""
|
||||
connector_instance = create_mock_connector()
|
||||
|
||||
requirements = '''# This is a comment
|
||||
requirements = """# This is a comment
|
||||
requests>=2.0
|
||||
|
||||
# Another comment
|
||||
flask==1.0
|
||||
|
||||
numpy'''
|
||||
numpy"""
|
||||
zip_bytes = create_zip_with_requirements(requirements)
|
||||
|
||||
task_context = Mock()
|
||||
@@ -147,9 +148,9 @@ numpy'''
|
||||
"""Test handling requirements.txt with only comments."""
|
||||
connector_instance = create_mock_connector()
|
||||
|
||||
requirements = '''# Comment 1
|
||||
requirements = """# Comment 1
|
||||
# Comment 2
|
||||
# Comment 3'''
|
||||
# Comment 3"""
|
||||
zip_bytes = create_zip_with_requirements(requirements)
|
||||
|
||||
task_context = Mock()
|
||||
|
||||
@@ -40,11 +40,13 @@ class TestHandlerQueryVariables:
|
||||
"""Test set_query_var returns error when query not found."""
|
||||
runtime_handler = make_handler(mock_app)
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
|
||||
'query_id': 'nonexistent-query',
|
||||
'key': 'test_var',
|
||||
'value': 'test_value',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
|
||||
{
|
||||
'query_id': 'nonexistent-query',
|
||||
'key': 'test_var',
|
||||
'value': 'test_value',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code != 0
|
||||
assert 'nonexistent-query' in response.message
|
||||
@@ -58,11 +60,13 @@ class TestHandlerQueryVariables:
|
||||
|
||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value]({
|
||||
'query_id': 'test-query',
|
||||
'key': 'test_var',
|
||||
'value': 'test_value',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.SET_QUERY_VAR.value](
|
||||
{
|
||||
'query_id': 'test-query',
|
||||
'key': 'test_var',
|
||||
'value': 'test_value',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert mock_query.variables['test_var'] == 'test_value'
|
||||
@@ -76,10 +80,12 @@ class TestHandlerQueryVariables:
|
||||
|
||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value]({
|
||||
'query_id': 'test-query',
|
||||
'key': 'existing_var',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VAR.value](
|
||||
{
|
||||
'query_id': 'test-query',
|
||||
'key': 'existing_var',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {'value': 'existing_value'}
|
||||
@@ -93,9 +99,11 @@ class TestHandlerQueryVariables:
|
||||
|
||||
mock_app.query_pool.cached_queries['test-query'] = mock_query
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value]({
|
||||
'query_id': 'test-query',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_QUERY_VARS.value](
|
||||
{
|
||||
'query_id': 'test-query',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {'vars': mock_query.variables}
|
||||
@@ -108,7 +116,7 @@ class TestHandlerRagErrorResponse:
|
||||
"""Test basic error response creation."""
|
||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||
|
||||
error = Exception("test error")
|
||||
error = Exception('test error')
|
||||
response = _make_rag_error_response(error, 'TestError')
|
||||
|
||||
# ActionResponse is a pydantic model, check message field
|
||||
@@ -120,13 +128,8 @@ class TestHandlerRagErrorResponse:
|
||||
"""Test error response with extra context."""
|
||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||
|
||||
error = ValueError("invalid input")
|
||||
response = _make_rag_error_response(
|
||||
error,
|
||||
'ValidationError',
|
||||
field='name',
|
||||
value='test'
|
||||
)
|
||||
error = ValueError('invalid input')
|
||||
response = _make_rag_error_response(error, 'ValidationError', field='name', value='test')
|
||||
|
||||
assert 'ValidationError' in response.message
|
||||
assert 'field=name' in response.message
|
||||
@@ -137,7 +140,7 @@ class TestHandlerRagErrorResponse:
|
||||
"""Test error response includes exception type."""
|
||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||
|
||||
error = RuntimeError("connection failed")
|
||||
error = RuntimeError('connection failed')
|
||||
response = _make_rag_error_response(error, 'ConnectionError')
|
||||
|
||||
assert 'RuntimeError' in response.message
|
||||
@@ -148,7 +151,7 @@ class TestHandlerRagErrorResponse:
|
||||
"""Test error response with no extra context."""
|
||||
from langbot.pkg.plugin.handler import _make_rag_error_response
|
||||
|
||||
error = KeyError("missing_key")
|
||||
error = KeyError('missing_key')
|
||||
response = _make_rag_error_response(error, 'LookupError')
|
||||
|
||||
# No context parts means no brackets
|
||||
|
||||
@@ -47,12 +47,14 @@ class TestInitializePluginSettings:
|
||||
Mock(),
|
||||
]
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
'install_source': 'local',
|
||||
'install_info': {'path': '/test'},
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
|
||||
{
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
'install_source': 'local',
|
||||
'install_info': {'path': '/test'},
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert app.persistence_mgr.execute_async.await_count == 2
|
||||
@@ -82,12 +84,14 @@ class TestInitializePluginSettings:
|
||||
Mock(),
|
||||
]
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value]({
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
'install_source': 'github',
|
||||
'install_info': {'repo': 'author/name'},
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.INITIALIZE_PLUGIN_SETTINGS.value](
|
||||
{
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
'install_source': 'github',
|
||||
'install_info': {'repo': 'author/name'},
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert app.persistence_mgr.execute_async.await_count == 3
|
||||
@@ -161,9 +165,7 @@ class TestSetBinaryStorage:
|
||||
runtime_handler = make_handler(app)
|
||||
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'old'))
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
|
||||
self.payload(b'new')
|
||||
)
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'new'))
|
||||
|
||||
assert response.code == 0
|
||||
assert app.persistence_mgr.execute_async.await_count == 2
|
||||
@@ -203,9 +205,7 @@ class TestSetBinaryStorage:
|
||||
runtime_handler = make_handler(app)
|
||||
app.instance_config.data['plugin']['binary_storage']['max_value_bytes'] = 0
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](
|
||||
self.payload(b'x')
|
||||
)
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.SET_BINARY_STORAGE.value](self.payload(b'x'))
|
||||
|
||||
assert response.code != 0
|
||||
assert '1 > 0 bytes' in response.message
|
||||
@@ -228,10 +228,12 @@ class TestGetPluginSettings:
|
||||
runtime_handler = make_handler(app)
|
||||
app.persistence_mgr.execute_async.return_value = make_result()
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
|
||||
{
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {
|
||||
@@ -255,10 +257,12 @@ class TestGetPluginSettings:
|
||||
)
|
||||
app.persistence_mgr.execute_async.return_value = make_result(setting)
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value]({
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_PLUGIN_SETTINGS.value](
|
||||
{
|
||||
'plugin_author': 'test-author',
|
||||
'plugin_name': 'test-plugin',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {
|
||||
@@ -286,11 +290,13 @@ class TestGetBinaryStorage:
|
||||
runtime_handler = make_handler(app)
|
||||
app.persistence_mgr.execute_async.return_value = make_result(SimpleNamespace(value=b'test binary content'))
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
|
||||
'key': 'test-key',
|
||||
'owner_type': 'plugin',
|
||||
'owner': 'test-owner',
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
|
||||
{
|
||||
'key': 'test-key',
|
||||
'owner_type': 'plugin',
|
||||
'owner': 'test-owner',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {
|
||||
@@ -303,11 +309,13 @@ class TestGetBinaryStorage:
|
||||
runtime_handler = make_handler(app)
|
||||
app.persistence_mgr.execute_async.return_value = make_result()
|
||||
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value]({
|
||||
'key': 'test-key',
|
||||
'owner_type': 'plugin',
|
||||
'owner': 'test-owner',
|
||||
})
|
||||
response = await runtime_handler.actions[RuntimeToLangBotAction.GET_BINARY_STORAGE.value](
|
||||
{
|
||||
'key': 'test-key',
|
||||
'owner_type': 'plugin',
|
||||
'owner': 'test-owner',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code != 0
|
||||
assert 'Storage with key test-key not found' in response.message
|
||||
@@ -329,9 +337,11 @@ class TestHandlerQueryLookup:
|
||||
"""Query-bound actions return error when query_id is not cached."""
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
|
||||
'query_id': 'nonexistent-query',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
|
||||
{
|
||||
'query_id': 'nonexistent-query',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code != 0
|
||||
assert 'nonexistent-query' in response.message
|
||||
@@ -343,9 +353,11 @@ class TestHandlerQueryLookup:
|
||||
query = SimpleNamespace(variables={}, bot_uuid='test-bot-uuid')
|
||||
app.query_pool.cached_queries['existing-query'] = query
|
||||
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value]({
|
||||
'query_id': 'existing-query',
|
||||
})
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_BOT_UUID.value](
|
||||
{
|
||||
'query_id': 'existing-query',
|
||||
}
|
||||
)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {'bot_uuid': 'test-bot-uuid'}
|
||||
|
||||
@@ -4,6 +4,7 @@ Tests cover:
|
||||
- _make_rag_error_response() helper function
|
||||
- RuntimeConnectionHandler cleanup_plugin_data method
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -23,7 +24,7 @@ class TestMakeRagErrorResponse:
|
||||
"""Test basic error response creation."""
|
||||
handler = get_handler_module()
|
||||
|
||||
error = ValueError("test error message")
|
||||
error = ValueError('test error message')
|
||||
result = handler._make_rag_error_response(error, 'TestError')
|
||||
|
||||
# ActionResponse.error() returns code=1 (error status)
|
||||
@@ -36,7 +37,7 @@ class TestMakeRagErrorResponse:
|
||||
"""Test that error type is included in message."""
|
||||
handler = get_handler_module()
|
||||
|
||||
error = RuntimeError("something went wrong")
|
||||
error = RuntimeError('something went wrong')
|
||||
result = handler._make_rag_error_response(error, 'VectorStoreError')
|
||||
|
||||
assert '[VectorStoreError/RuntimeError]' in result.message
|
||||
@@ -45,7 +46,7 @@ class TestMakeRagErrorResponse:
|
||||
"""Test that extra context fields are included."""
|
||||
handler = get_handler_module()
|
||||
|
||||
error = Exception("embedding failed")
|
||||
error = Exception('embedding failed')
|
||||
result = handler._make_rag_error_response(
|
||||
error,
|
||||
'EmbeddingError',
|
||||
@@ -71,7 +72,7 @@ class TestMakeRagErrorResponse:
|
||||
"""Test multiple context fields are comma separated."""
|
||||
handler = get_handler_module()
|
||||
|
||||
error = IOError("file not found")
|
||||
error = IOError('file not found')
|
||||
result = handler._make_rag_error_response(
|
||||
error,
|
||||
'FileServiceError',
|
||||
@@ -119,9 +120,7 @@ class TestCleanupPluginData:
|
||||
handler_instance = Mock(spec=handler_module.RuntimeConnectionHandler)
|
||||
handler_instance.ap = mock_app
|
||||
|
||||
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(
|
||||
handler_instance, 'author', 'plugin-name'
|
||||
)
|
||||
await handler_module.RuntimeConnectionHandler.cleanup_plugin_data(handler_instance, 'author', 'plugin-name')
|
||||
|
||||
# Should have at least 2 calls: one for settings, one for binary storage
|
||||
assert mock_app.persistence_mgr.execute_async.call_count >= 2
|
||||
assert mock_app.persistence_mgr.execute_async.call_count >= 2
|
||||
|
||||
@@ -88,7 +88,10 @@ class AnotherFakeRequester(requester.ProviderAPIRequester):
|
||||
|
||||
async def invoke_llm(self, query, model, messages, funcs=None, extra_args={}, remove_think=False):
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
return provider_message.Message(role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')])
|
||||
|
||||
return provider_message.Message(
|
||||
role='assistant', content=[provider_message.ContentElement(type='text', text='Another response')]
|
||||
)
|
||||
|
||||
async def invoke_rerank(self, model, query: str, documents: list, extra_args={}):
|
||||
"""Return fake rerank results."""
|
||||
@@ -135,8 +138,10 @@ def mock_app_for_modelmgr():
|
||||
|
||||
# Fake persistence manager - returns empty results by default
|
||||
app.persistence_mgr = SimpleNamespace()
|
||||
|
||||
async def default_execute(query):
|
||||
return _make_mock_result([])
|
||||
|
||||
app.persistence_mgr.execute_async = AsyncMock(side_effect=default_execute)
|
||||
|
||||
# Fake discover engine
|
||||
@@ -165,9 +170,7 @@ def fake_requester_registry(mock_app_for_modelmgr):
|
||||
fake_component = _create_fake_component('fake-requester', FakeProviderAPIRequester)
|
||||
another_component = _create_fake_component('another-fake-requester', AnotherFakeRequester)
|
||||
|
||||
app.discover.get_components_by_kind = Mock(
|
||||
return_value=[fake_component, another_component]
|
||||
)
|
||||
app.discover.get_components_by_kind = Mock(return_value=[fake_component, another_component])
|
||||
|
||||
model_mgr = ModelManager(app)
|
||||
return model_mgr
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestDifyExtractTextOutput:
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
@@ -111,7 +111,7 @@ class TestDifyRunnerConfigValidation:
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
with pytest.raises(DifyAPIError, match='不支持'):
|
||||
@@ -134,7 +134,7 @@ class TestDifyRunnerConfigValidation:
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
@@ -160,10 +160,10 @@ class TestDifyRunnerInit:
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
'output': {'misc': {}},
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
|
||||
assert runner.pipeline_config == pipeline_config
|
||||
assert runner.ap == mock_app
|
||||
assert runner.ap == mock_app
|
||||
|
||||
@@ -1062,9 +1062,7 @@ class TestScanModels:
|
||||
|
||||
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
|
||||
mock_get_model_info.side_effect = (
|
||||
lambda model: {'max_input_tokens': 131072}
|
||||
if model == 'moonshot/moonshot-v1-128k'
|
||||
else {}
|
||||
lambda model: {'max_input_tokens': 131072} if model == 'moonshot/moonshot-v1-128k' else {}
|
||||
)
|
||||
|
||||
assert requester._safe_context_length('moonshot-v1-128k') == 131072
|
||||
|
||||
@@ -180,7 +180,7 @@ class TestMCPServerBoxConfig:
|
||||
assert cfg.host_path is None
|
||||
assert cfg.host_path_mode == 'ro'
|
||||
assert cfg.env == {}
|
||||
assert cfg.startup_timeout_sec == 300
|
||||
assert cfg.startup_timeout_sec == 120
|
||||
assert cfg.cpus is None
|
||||
assert cfg.memory_mb is None
|
||||
assert cfg.pids_limit is None
|
||||
@@ -494,84 +494,6 @@ class TestBuildBoxProcessPayload:
|
||||
assert payload['args'] == ['/opt/other/server.py', '--flag']
|
||||
|
||||
|
||||
# ── Python Workspace Preparation ────────────────────────────────────
|
||||
|
||||
|
||||
class TestPythonWorkspacePreparation:
|
||||
def test_requirements_workspace_uses_venv_bootstrap(self, mcp_module, tmp_path):
|
||||
host_path = tmp_path / 'mcp-source'
|
||||
host_path.mkdir()
|
||||
(host_path / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8')
|
||||
|
||||
command = mcp_module.BoxStdioSessionRuntime.detect_install_command(
|
||||
str(host_path),
|
||||
'/workspace/.mcp/u1/workspace',
|
||||
)
|
||||
|
||||
assert command is not None
|
||||
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
|
||||
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
|
||||
assert 'python -m pip install -r "/workspace/.mcp/u1/workspace/requirements.txt"' in command
|
||||
assert 'pip install --no-cache-dir -r' not in command
|
||||
|
||||
def test_staging_refresh_removes_stale_source_files_but_preserves_runtime_dirs(self, mcp_module, tmp_path):
|
||||
source = tmp_path / 'source'
|
||||
source.mkdir()
|
||||
(source / 'server.py').write_text('print("new")\n', encoding='utf-8')
|
||||
(source / 'requirements.txt').write_text('mcp==1.26.0\n', encoding='utf-8')
|
||||
(source / '.env').write_text('TOKEN=new\n', encoding='utf-8')
|
||||
|
||||
process_root = tmp_path / 'shared' / '.mcp' / 'u1'
|
||||
workspace = process_root / 'workspace'
|
||||
(workspace / '.venv' / 'bin').mkdir(parents=True)
|
||||
(workspace / '.venv' / 'bin' / 'python').write_text('', encoding='utf-8')
|
||||
(workspace / '.langbot').mkdir()
|
||||
(workspace / '.langbot' / 'python-env.lock').mkdir()
|
||||
(workspace / '.env').write_text('TOKEN=old\n', encoding='utf-8')
|
||||
(workspace / 'server.py').write_text('print("old")\n', encoding='utf-8')
|
||||
(workspace / 'removed.py').write_text('stale\n', encoding='utf-8')
|
||||
(workspace / 'removed_dir').mkdir()
|
||||
(workspace / 'removed_dir' / 'old.txt').write_text('stale\n', encoding='utf-8')
|
||||
|
||||
mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace))
|
||||
|
||||
assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n'
|
||||
assert (workspace / 'requirements.txt').read_text(encoding='utf-8') == 'mcp==1.26.0\n'
|
||||
assert (workspace / '.env').read_text(encoding='utf-8') == 'TOKEN=new\n'
|
||||
assert not (workspace / 'removed.py').exists()
|
||||
assert not (workspace / 'removed_dir').exists()
|
||||
assert (workspace / '.venv' / 'bin' / 'python').exists()
|
||||
assert (workspace / '.langbot' / 'python-env.lock').is_dir()
|
||||
|
||||
def test_staging_refresh_ignores_unlink_race(self, mcp_module, tmp_path, monkeypatch):
|
||||
mcp_stdio_module = sys.modules['langbot.pkg.provider.tools.loaders.mcp_stdio']
|
||||
|
||||
source = tmp_path / 'source'
|
||||
source.mkdir()
|
||||
(source / 'server.py').write_text('print("new")\n', encoding='utf-8')
|
||||
|
||||
process_root = tmp_path / 'shared' / '.mcp' / 'u1'
|
||||
workspace = process_root / 'workspace'
|
||||
workspace.mkdir(parents=True)
|
||||
stale_file = workspace / 'removed.py'
|
||||
stale_file.write_text('stale\n', encoding='utf-8')
|
||||
|
||||
real_unlink = os.unlink
|
||||
|
||||
def unlink_with_race(path):
|
||||
if os.fspath(path) == str(stale_file):
|
||||
real_unlink(path)
|
||||
raise FileNotFoundError(path)
|
||||
real_unlink(path)
|
||||
|
||||
monkeypatch.setattr(mcp_stdio_module.os, 'unlink', unlink_with_race)
|
||||
|
||||
mcp_module.BoxStdioSessionRuntime._copy_workspace_tree(str(source), str(process_root), str(workspace))
|
||||
|
||||
assert not stale_file.exists()
|
||||
assert (workspace / 'server.py').read_text(encoding='utf-8') == 'print("new")\n'
|
||||
|
||||
|
||||
# ── get_runtime_info_dict ───────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -635,7 +635,9 @@ async def test_model_manager_reload_provider_not_found(fake_requester_registry):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_manager_load_llm_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||
async def test_model_manager_load_llm_model_with_provider(
|
||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
||||
):
|
||||
"""Test ModelManager.load_llm_model_with_provider creates RuntimeLLMModel."""
|
||||
model_mgr = fake_requester_registry
|
||||
|
||||
@@ -648,7 +650,9 @@ async def test_model_manager_load_llm_model_with_provider(fake_requester_registr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_manager_load_llm_model_with_provider_from_row(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||
async def test_model_manager_load_llm_model_with_provider_from_row(
|
||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
||||
):
|
||||
"""Test ModelManager.load_llm_model_with_provider handles Row objects."""
|
||||
model_mgr = fake_requester_registry
|
||||
|
||||
@@ -661,7 +665,9 @@ async def test_model_manager_load_llm_model_with_provider_from_row(fake_requeste
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_manager_load_embedding_model_with_provider(fake_requester_registry, fake_persistence_data, runtime_provider):
|
||||
async def test_model_manager_load_embedding_model_with_provider(
|
||||
fake_requester_registry, fake_persistence_data, runtime_provider
|
||||
):
|
||||
"""Test ModelManager.load_embedding_model_with_provider creates RuntimeEmbeddingModel."""
|
||||
model_mgr = fake_requester_registry
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ class TestableRequester(requester.ProviderAPIRequester):
|
||||
remove_think=False,
|
||||
):
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
return provider_message.Message(
|
||||
role='assistant',
|
||||
content=[provider_message.ContentElement(type='text', text='Testable response')],
|
||||
@@ -289,7 +290,9 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
|
||||
current_stage_name=None,
|
||||
)
|
||||
|
||||
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||
messages = [
|
||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
||||
]
|
||||
|
||||
result = await provider.invoke_llm(query, runtime_llm_model, messages)
|
||||
|
||||
@@ -330,7 +333,9 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
|
||||
current_stage_name=None,
|
||||
)
|
||||
|
||||
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||
messages = [
|
||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
||||
]
|
||||
|
||||
chunks = []
|
||||
async for chunk in provider.invoke_llm_stream(query, runtime_llm_model, messages):
|
||||
@@ -576,7 +581,9 @@ async def test_runtime_provider_invoke_llm_propagates_error(mock_app_for_modelmg
|
||||
current_stage_name=None,
|
||||
)
|
||||
|
||||
messages = [provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])]
|
||||
messages = [
|
||||
provider_message.Message(role='user', content=[provider_message.ContentElement(type='text', text='Hello')])
|
||||
]
|
||||
|
||||
with pytest.raises(RequesterError):
|
||||
await provider.invoke_llm(query, model, messages)
|
||||
|
||||
@@ -5,6 +5,7 @@ Tests cover:
|
||||
- Conversation creation with prompts
|
||||
- Session concurrency semaphore
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
@@ -60,11 +61,7 @@ class TestSessionManagerGetSession:
|
||||
"""Create mock app with instance config."""
|
||||
mock_app = Mock()
|
||||
mock_app.instance_config = Mock()
|
||||
mock_app.instance_config.data = {
|
||||
'concurrency': {
|
||||
'session': 5
|
||||
}
|
||||
}
|
||||
mock_app.instance_config.data = {'concurrency': {'session': 5}}
|
||||
return mock_app
|
||||
|
||||
@pytest.fixture
|
||||
@@ -173,11 +170,7 @@ class TestSessionManagerGetConversation:
|
||||
"""Create mock app with instance config."""
|
||||
mock_app = Mock()
|
||||
mock_app.instance_config = Mock()
|
||||
mock_app.instance_config.data = {
|
||||
'concurrency': {
|
||||
'session': 5
|
||||
}
|
||||
}
|
||||
mock_app.instance_config.data = {'concurrency': {'session': 5}}
|
||||
return mock_app
|
||||
|
||||
@pytest.fixture
|
||||
@@ -201,17 +194,13 @@ class TestSessionManagerGetConversation:
|
||||
return query
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_conversation_with_prompt(
|
||||
self, mock_app_with_config, sample_query, sample_session
|
||||
):
|
||||
async def test_creates_conversation_with_prompt(self, mock_app_with_config, sample_query, sample_session):
|
||||
"""Test that get_conversation creates conversation with prompt."""
|
||||
sessionmgr = get_session_module()
|
||||
|
||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||
|
||||
prompt_config = [
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||
]
|
||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
||||
pipeline_uuid = 'pipeline-123'
|
||||
bot_uuid = 'bot-123'
|
||||
|
||||
@@ -234,21 +223,15 @@ class TestSessionManagerGetConversation:
|
||||
|
||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||
|
||||
prompt_config = [
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||
]
|
||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
||||
pipeline_uuid = 'pipeline-123'
|
||||
bot_uuid = 'bot-123'
|
||||
|
||||
# First call creates conversation
|
||||
conv1 = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
|
||||
)
|
||||
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
|
||||
|
||||
# Second call with same pipeline should return same conversation
|
||||
conv2 = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid
|
||||
)
|
||||
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, pipeline_uuid, bot_uuid)
|
||||
|
||||
assert conv1 is conv2
|
||||
assert len(sample_session.conversations) == 1
|
||||
@@ -262,36 +245,26 @@ class TestSessionManagerGetConversation:
|
||||
|
||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||
|
||||
prompt_config = [
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||
]
|
||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
||||
|
||||
# First call with pipeline1
|
||||
conv1 = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1'
|
||||
)
|
||||
conv1 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-1', 'bot-1')
|
||||
|
||||
# Second call with different pipeline should create new conversation
|
||||
conv2 = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2'
|
||||
)
|
||||
conv2 = await manager.get_conversation(sample_query, sample_session, prompt_config, 'pipeline-2', 'bot-2')
|
||||
|
||||
assert conv1 is not conv2
|
||||
assert len(sample_session.conversations) == 2
|
||||
assert sample_session.using_conversation is conv2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_has_empty_messages(
|
||||
self, mock_app_with_config, sample_query, sample_session
|
||||
):
|
||||
async def test_conversation_has_empty_messages(self, mock_app_with_config, sample_query, sample_session):
|
||||
"""Test that created conversation has empty messages list."""
|
||||
sessionmgr = get_session_module()
|
||||
|
||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||
|
||||
prompt_config = [
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||
]
|
||||
prompt_config = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
|
||||
|
||||
conversation = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
||||
@@ -300,22 +273,17 @@ class TestSessionManagerGetConversation:
|
||||
assert conversation.messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_messages_from_config(
|
||||
self, mock_app_with_config, sample_query, sample_session
|
||||
):
|
||||
async def test_prompt_messages_from_config(self, mock_app_with_config, sample_query, sample_session):
|
||||
"""Test that prompt messages are created from prompt_config."""
|
||||
sessionmgr = get_session_module()
|
||||
|
||||
manager = sessionmgr.SessionManager(mock_app_with_config)
|
||||
|
||||
prompt_config = [
|
||||
{'role': 'system', 'content': 'System message'},
|
||||
{'role': 'user', 'content': 'User message'}
|
||||
]
|
||||
prompt_config = [{'role': 'system', 'content': 'System message'}, {'role': 'user', 'content': 'User message'}]
|
||||
|
||||
conversation = await manager.get_conversation(
|
||||
sample_query, sample_session, prompt_config, 'pipeline-123', 'bot-123'
|
||||
)
|
||||
|
||||
assert conversation.prompt.name == 'default'
|
||||
assert len(conversation.prompt.messages) == 2
|
||||
assert len(conversation.prompt.messages) == 2
|
||||
|
||||
@@ -193,29 +193,6 @@ class TestSkillPathHelpers:
|
||||
|
||||
assert list(result.keys()) == ['visible']
|
||||
|
||||
def test_restore_activated_skills_uses_caller_provided_names_and_visibility(self):
|
||||
from langbot.pkg.provider.tools.loaders.skill import (
|
||||
ACTIVATED_SKILLS_KEY,
|
||||
PIPELINE_BOUND_SKILLS_KEY,
|
||||
get_activated_skill_names,
|
||||
restore_activated_skills,
|
||||
)
|
||||
|
||||
ap = _make_ap()
|
||||
ap.skill_mgr = SimpleNamespace(
|
||||
skills={
|
||||
'visible': _make_skill_data(name='visible'),
|
||||
'hidden': _make_skill_data(name='hidden'),
|
||||
}
|
||||
)
|
||||
query = SimpleNamespace(variables={PIPELINE_BOUND_SKILLS_KEY: ['visible']})
|
||||
|
||||
restored = restore_activated_skills(ap, query, ['visible', 'hidden', 'visible', ''])
|
||||
|
||||
assert restored == ['visible']
|
||||
assert list(query.variables[ACTIVATED_SKILLS_KEY].keys()) == ['visible']
|
||||
assert get_activated_skill_names(query) == ['visible']
|
||||
|
||||
def test_resolve_virtual_skill_path_allows_visible_skill_reads(self):
|
||||
from langbot.pkg.provider.tools.loaders.skill import (
|
||||
PIPELINE_BOUND_SKILLS_KEY,
|
||||
@@ -268,8 +245,7 @@ class TestSkillPathHelpers:
|
||||
|
||||
command = wrap_skill_command_with_python_env('python scripts/run.py')
|
||||
|
||||
assert '_LB_SYSTEM_PYTHON="$(command -v python3 || command -v python || true)"' in command
|
||||
assert '"$_LB_SYSTEM_PYTHON" -m venv "$_LB_VENV_DIR"' in command
|
||||
assert 'python -m venv "$_LB_VENV_DIR"' in command
|
||||
assert 'export VIRTUAL_ENV="$_LB_VENV_DIR"' in command
|
||||
assert command.rstrip().endswith('python scripts/run.py')
|
||||
|
||||
@@ -305,7 +281,6 @@ class TestSkillToolLoader:
|
||||
assert result['activated'] is True
|
||||
assert result['skill_name'] == 'demo'
|
||||
assert result['mount_path'] == '/workspace/.skills/demo'
|
||||
assert result['activated_skill_names'] == ['demo']
|
||||
assert 'Step 1' in result['content']
|
||||
assert set(query.variables[ACTIVATED_SKILLS_KEY].keys()) == {'demo'}
|
||||
|
||||
@@ -481,9 +456,7 @@ class TestNativeToolLoaderSkillPaths:
|
||||
SimpleNamespace(query_id='q1', variables={PIPELINE_BOUND_SKILLS_KEY: ['demo']}),
|
||||
)
|
||||
|
||||
assert result['ok'] is True
|
||||
assert result['content'] == 'demo instructions'
|
||||
assert result['truncated'] is False
|
||||
assert result == {'ok': True, 'content': 'demo instructions'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_in_activated_skill_mount_rewrites_command_and_refreshes(self):
|
||||
@@ -512,7 +485,7 @@ class TestNativeToolLoaderSkillPaths:
|
||||
query,
|
||||
)
|
||||
|
||||
assert result['ok'] is True
|
||||
assert result == {'ok': True}
|
||||
tool_parameters = ap.box_service.execute_tool.await_args.args[0]
|
||||
assert tool_parameters['command'] == 'python /workspace/.skills/demo/scripts/run.py'
|
||||
assert tool_parameters['workdir'] == '/workspace/.skills/demo'
|
||||
|
||||
@@ -136,6 +136,7 @@ class TestToolManagerSchemaGeneration:
|
||||
assert 'description' in func
|
||||
assert 'parameters' in func
|
||||
|
||||
|
||||
class TestToolManagerExecuteFuncCall:
|
||||
"""Tests for execute_func_call method."""
|
||||
|
||||
|
||||
@@ -248,135 +248,3 @@ async def test_path_escape_blocked():
|
||||
|
||||
with pytest.raises(ValueError, match='escapes'):
|
||||
await loader.invoke_tool('read', {'path': '/workspace/../../etc/passwd'}, _make_query())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_box_availability_helper_handles_unavailable_and_errors():
|
||||
from langbot.pkg.provider.tools.loaders.availability import is_box_backend_available
|
||||
|
||||
assert await is_box_backend_available(SimpleNamespace()) is False
|
||||
assert await is_box_backend_available(SimpleNamespace(box_service=SimpleNamespace(available=False))) is False
|
||||
|
||||
unavailable_backend = SimpleNamespace(
|
||||
available=True,
|
||||
get_status=AsyncMock(return_value={'backend': {'available': False}}),
|
||||
)
|
||||
assert await is_box_backend_available(SimpleNamespace(box_service=unavailable_backend)) is False
|
||||
|
||||
failing_backend = SimpleNamespace(
|
||||
available=True,
|
||||
get_status=AsyncMock(side_effect=RuntimeError('box unavailable')),
|
||||
)
|
||||
assert await is_box_backend_available(SimpleNamespace(box_service=failing_backend)) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_supports_offset_limit_and_truncation_metadata():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
loader, _ = _make_loader_with_workspace(tmpdir)
|
||||
with open(os.path.join(tmpdir, 'lines.txt'), 'w', encoding='utf-8') as f:
|
||||
f.write('one\ntwo\nthree\nfour\n')
|
||||
|
||||
result = await loader.invoke_tool(
|
||||
'read',
|
||||
{'path': '/workspace/lines.txt', 'offset': 2, 'limit': 2},
|
||||
_make_query(),
|
||||
)
|
||||
|
||||
assert result == {
|
||||
'ok': True,
|
||||
'content': 'two\nthree',
|
||||
'truncated': True,
|
||||
'truncated_by': 'lines',
|
||||
'start_line': 2,
|
||||
'end_line': 3,
|
||||
'next_offset': 4,
|
||||
'max_lines': 2,
|
||||
'max_bytes': 50 * 1024,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handles_line_larger_than_byte_limit():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
loader, _ = _make_loader_with_workspace(tmpdir)
|
||||
with open(os.path.join(tmpdir, 'long-line.txt'), 'w', encoding='utf-8') as f:
|
||||
f.write('abcdef\n')
|
||||
|
||||
result = await loader.invoke_tool(
|
||||
'read',
|
||||
{'path': '/workspace/long-line.txt', 'max_bytes': 3},
|
||||
_make_query(),
|
||||
)
|
||||
|
||||
assert result['ok'] is True
|
||||
assert result['truncated'] is True
|
||||
assert result['truncated_by'] == 'bytes'
|
||||
assert result['next_offset'] == 1
|
||||
assert 'exceeds the 3B read limit' in result['content']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_result_is_capped_and_exposes_preview_metadata():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
box_service = SimpleNamespace(
|
||||
available=True,
|
||||
default_workspace=tmpdir,
|
||||
execute_tool=AsyncMock(
|
||||
return_value={
|
||||
'ok': True,
|
||||
'stdout': 'a' * 60000,
|
||||
'stderr': 'b' * 60000,
|
||||
'exit_code': 0,
|
||||
}
|
||||
),
|
||||
)
|
||||
loader = NativeToolLoader(SimpleNamespace(box_service=box_service, logger=Mock()))
|
||||
|
||||
result = await loader.invoke_tool('exec', {'command': 'python -V'}, _make_query())
|
||||
|
||||
assert result['ok'] is True
|
||||
assert len(result['stdout'].encode('utf-8')) == 50 * 1024
|
||||
assert len(result['stderr'].encode('utf-8')) == 50 * 1024
|
||||
assert len(result['preview'].encode('utf-8')) == 50 * 1024
|
||||
assert result['stdout_truncated'] is True
|
||||
assert result['stderr_truncated'] is True
|
||||
assert result['truncated'] is True
|
||||
assert result['truncated_by'] == 'bytes'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_glob_caps_match_count_and_returns_preview():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
loader, _ = _make_loader_with_workspace(tmpdir)
|
||||
for index in range(105):
|
||||
with open(os.path.join(tmpdir, f'file-{index:03d}.txt'), 'w', encoding='utf-8') as f:
|
||||
f.write(str(index))
|
||||
|
||||
result = await loader.invoke_tool('glob', {'path': '/workspace', 'pattern': '*.txt'}, _make_query())
|
||||
|
||||
assert result['ok'] is True
|
||||
assert result['total'] == 105
|
||||
assert len(result['matches']) == 100
|
||||
assert result['preview'] == '\n'.join(result['matches'])
|
||||
assert result['truncated'] is True
|
||||
assert result['truncated_by'] == 'matches'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grep_reports_invalid_regex_and_truncates_long_matching_lines():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
loader, _ = _make_loader_with_workspace(tmpdir)
|
||||
with open(os.path.join(tmpdir, 'data.txt'), 'w', encoding='utf-8') as f:
|
||||
f.write('needle ' + ('x' * 600) + '\n')
|
||||
|
||||
invalid = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': '['}, _make_query())
|
||||
result = await loader.invoke_tool('grep', {'path': '/workspace', 'pattern': 'needle'}, _make_query())
|
||||
|
||||
assert invalid['ok'] is False
|
||||
assert 'Invalid regex' in invalid['error']
|
||||
assert result['ok'] is True
|
||||
assert result['truncated'] is True
|
||||
assert result['truncated_by'] == 'line'
|
||||
assert result['matches'][0]['file'] == '/workspace/data.txt'
|
||||
assert result['matches'][0]['content'].endswith('... [truncated]')
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Tests cover:
|
||||
- _to_i18n_name() static method
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
@@ -60,4 +61,4 @@ class TestToI18nName:
|
||||
kbmgr = get_kbmgr_module()
|
||||
input_dict = {'en_US': 'English', 'extra_key': 'extra_value'}
|
||||
result = kbmgr.RAGManager._to_i18n_name(input_dict)
|
||||
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
|
||||
assert result == {'en_US': 'English', 'extra_key': 'extra_value'}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user