Compare commits

..

1 Commits

Author SHA1 Message Date
fdc310
92c3b81014 feat: Support for interactive card message processing 2026-05-05 17:44:28 +08:00
215 changed files with 4747 additions and 41409 deletions

View File

@@ -4,29 +4,25 @@ on:
pull_request:
types: [opened, ready_for_review, synchronize]
paths:
- 'src/langbot/**'
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'uv.lock'
- 'run_tests.sh'
- 'scripts/test-*.sh'
push:
branches:
- master
- develop
paths:
- 'src/langbot/**'
- 'pkg/**'
- 'tests/**'
- '.github/workflows/run-tests.yml'
- 'pyproject.toml'
- 'uv.lock'
- 'run_tests.sh'
- 'scripts/test-*.sh'
jobs:
test:
name: Unit Tests
name: Run Unit Tests
runs-on: ubuntu-latest
strategy:
matrix:
@@ -43,13 +39,28 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v4
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install dependencies
run: uv sync --dev
run: |
uv sync --dev
- name: Run unit + smoke tests
run: uv run pytest tests/unit_tests/ tests/smoke/ -q --tb=short
- name: Run unit tests
run: |
bash run_tests.sh
- name: Upload coverage to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unit-tests
name: unit-tests-coverage
fail_ci_if_error: false
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Test Summary
if: always()
@@ -58,79 +69,3 @@ jobs:
echo "" >> $GITHUB_STEP_SUMMARY
echo "Python Version: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
integration:
name: Fast Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run fast integration tests
run: uv run pytest tests/integration/ -m "not slow" -q --tb=short
- name: Integration Test Summary
if: always()
run: |
echo "## Integration Tests Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Test Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY
coverage:
name: Coverage Gate
runs-on: ubuntu-latest
needs: [test, integration]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run coverage (unit + smoke)
run: |
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=xml \
--cov-report=term-missing \
--cov-fail-under=18 \
-q --tb=short
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unit-tests
name: coverage-report
fail_ci_if_error: false
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Coverage Summary
if: always()
run: |
echo "## Coverage Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Threshold: 18%" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ job.status }}" >> $GITHUB_STEP_SUMMARY

View File

@@ -9,13 +9,11 @@ on:
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'src/langbot/pkg/persistence/**'
- 'src/langbot/pkg/entity/persistence/**'
- 'tests/integration/persistence/**'
jobs:
test-migrations-sqlite:
@@ -36,8 +34,52 @@ jobs:
- name: Install dependencies
run: uv sync --dev
- name: Run SQLite migration tests
run: uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
- name: Test Alembic upgrade (SQLite)
run: |
uv run python -c "
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
async def main():
engine = create_async_engine('sqlite+aiosqlite:///test_migrations.db')
# Create all tables (simulates existing DB)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(engine, '0001_baseline')
rev = await get_alembic_current(engine)
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
print(f'Stamped: {rev}')
# Upgrade to head
await run_alembic_upgrade(engine, 'head')
rev = await get_alembic_current(engine)
print(f'After upgrade: {rev}')
assert rev is not None, 'Expected a revision after upgrade'
# Verify idempotent
await run_alembic_upgrade(engine, 'head')
rev2 = await get_alembic_current(engine)
assert rev2 == rev, f'Expected {rev}, got {rev2}'
print(f'Idempotent check passed: {rev2}')
# Fresh DB: upgrade from scratch
engine2 = create_async_engine('sqlite+aiosqlite:///test_migrations_fresh.db')
async with engine2.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_upgrade(engine2, 'head')
rev3 = await get_alembic_current(engine2)
print(f'Fresh DB upgrade: {rev3}')
assert rev3 is not None
print('All SQLite migration tests passed!')
asyncio.run(main())
"
test-migrations-postgres:
name: Migrations (PostgreSQL)
@@ -72,7 +114,58 @@ jobs:
- name: Install dependencies
run: uv sync --dev
- name: Run PostgreSQL migration tests
env:
TEST_POSTGRES_URL: postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test
run: uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
- name: Test Alembic upgrade (PostgreSQL)
run: |
uv run python -c "
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import run_alembic_upgrade, run_alembic_stamp, get_alembic_current
DB_URL = 'postgresql+asyncpg://langbot:langbot@localhost:5432/langbot_test'
async def main():
engine = create_async_engine(DB_URL)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(engine, '0001_baseline')
rev = await get_alembic_current(engine)
assert rev == '0001_baseline', f'Expected 0001_baseline, got {rev}'
print(f'Stamped: {rev}')
# Upgrade to head
await run_alembic_upgrade(engine, 'head')
rev = await get_alembic_current(engine)
print(f'After upgrade: {rev}')
assert rev is not None
# Verify idempotent
await run_alembic_upgrade(engine, 'head')
rev2 = await get_alembic_current(engine)
assert rev2 == rev, f'Expected {rev}, got {rev2}'
print(f'Idempotent check passed: {rev2}')
# Fresh DB: drop all and upgrade from scratch
engine2 = create_async_engine(DB_URL.replace('langbot_test', 'langbot_fresh'))
# Create fresh database
from sqlalchemy import text
async with engine.connect() as conn:
await conn.execute(text('COMMIT'))
await conn.execute(text('CREATE DATABASE langbot_fresh'))
async with engine2.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_upgrade(engine2, 'head')
rev3 = await get_alembic_current(engine2)
print(f'Fresh DB upgrade: {rev3}')
assert rev3 is not None
print('All PostgreSQL migration tests passed!')
asyncio.run(main())
"

View File

@@ -1,36 +0,0 @@
# LangBot Makefile
# Quick developer commands
.PHONY: test test-quick test-integration-fast test-coverage test-all-local lint
# Run all tests (full suite with coverage)
test:
bash run_tests.sh
# Quick self-test for developers (lint + unit + smoke, no real credentials needed)
test-quick:
bash scripts/test-quick.sh
# Fast integration tests (SQLite/API/Pipeline, no external services)
test-integration-fast:
bash scripts/test-integration-fast.sh
# Coverage gate (all tests, enforces minimum threshold)
test-coverage:
bash scripts/test-coverage.sh
# Full local quality gate (quick + integration + coverage)
test-all-local:
bash scripts/test-quick.sh
bash scripts/test-integration-fast.sh
bash scripts/test-coverage.sh
# Run linting only
lint:
ruff check src/langbot/ tests/
ruff format --check src/langbot/ tests/
# Fix linting issues
lint-fix:
ruff check --fix src/langbot/ tests/
ruff format src/langbot/ tests/

View File

@@ -47,8 +47,6 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
[→ Learn more about all features](https://link.langbot.app/en/docs/features)
📍 Practical guides: [deploy a multi-platform AI bot in 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connect DeepSeek to WeChat, Discord, and Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [run a Dify Agent in Discord, Telegram, and Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), and [build an n8n-powered chatbot](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Quick Start

View File

@@ -47,8 +47,6 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
📍 实践指南:[5 分钟部署多平台 AI 机器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[将 DeepSeek 接入微信、企业微信与 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[让 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 构建多平台 AI 聊天机器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
---
## 快速开始

View File

@@ -46,8 +46,6 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
[→ Conocer más sobre todas las funcionalidades](https://link.langbot.app/en/docs/features)
📍 Guías prácticas: [desplegar un bot de IA multiplataforma en 5 minutos](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [conectar DeepSeek a WeChat, Discord y Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [ejecutar un Dify Agent en Discord, Telegram y Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) y [crear un chatbot con n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Inicio Rápido

View File

@@ -46,8 +46,6 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
[→ En savoir plus sur toutes les fonctionnalités](https://link.langbot.app/en/docs/features)
📍 Guides pratiques : [déployer un bot IA multiplateforme en 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connecter DeepSeek à WeChat, Discord et Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [exécuter un Dify Agent dans Discord, Telegram et Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) et [créer un chatbot avec n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Démarrage Rapide

View File

@@ -46,8 +46,6 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
[→ すべての機能について詳しく見る](https://link.langbot.app/ja/docs/features)
📍 実践ガイド: [5分でマルチプラットフォームAIボットをデプロイ](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/)、[DeepSeekをWeChat・Discord・Telegramに接続](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/)、[Dify AgentをDiscord・Telegram・Slackで動かす](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/)、[n8n連携チャットボットを構築](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/)。
---
## クイックスタート

View File

@@ -46,8 +46,6 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
[→ 모든 기능 자세히 보기](https://link.langbot.app/en/docs/features)
📍 실전 가이드: [5분 만에 멀티 플랫폼 AI 봇 배포하기](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [DeepSeek를 WeChat, Discord, Telegram에 연결하기](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [Dify Agent를 Discord, Telegram, Slack에서 실행하기](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), [n8n 기반 챗봇 만들기](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## 빠른 시작

View File

@@ -46,8 +46,6 @@ LangBot — это **платформа с открытым исходным к
[→ Подробнее обо всех возможностях](https://link.langbot.app/en/docs/features)
📍 Практические руководства: [развернуть мультиплатформенного ИИ-бота за 5 минут](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [подключить DeepSeek к WeChat, Discord и Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [запустить Dify Agent в Discord, Telegram и Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) и [создать чат-бота на n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Быстрый старт

View File

@@ -48,8 +48,6 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
📍 實踐指南:[5 分鐘部署多平台 AI 機器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[將 DeepSeek 接入微信、企業微信與 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[讓 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 建構多平台 AI 聊天機器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
---
## 快速開始

View File

@@ -46,8 +46,6 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
[→ Tìm hiểu thêm về tất cả tính năng](https://link.langbot.app/en/docs/features)
📍 Hướng dẫn thực hành: [triển khai bot AI đa nền tảng trong 5 phút](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [kết nối DeepSeek với WeChat, Discord và Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [chạy Dify Agent trên Discord, Telegram và Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) và [xây dựng chatbot với n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
---
## Bắt đầu nhanh

View File

@@ -1,6 +1,6 @@
[project]
name = "langbot"
version = "4.9.7"
version = "4.9.6"
description = "Production-grade platform for building agentic IM bots"
readme = "README.md"
license-files = ["LICENSE"]
@@ -22,7 +22,7 @@ dependencies = [
"discord-py>=2.5.2",
"pynacl>=1.5.0", # Required for Discord voice support
"gewechat-client>=0.1.5",
"lark-oapi>=1.5.5",
"lark-oapi>=1.4.15",
"mcp>=1.25.0",
"nakuru-project-idk>=0.0.2.1",
"ollama>=0.4.8",
@@ -35,7 +35,6 @@ dependencies = [
"python-telegram-bot>=22.0",
"pyyaml>=6.0.2",
"qq-botpy-rc>=1.2.1.6",
"qrcode>=7.4",
"quart>=0.20.0",
"quart-cors>=0.8.0",
"requests>=2.32.3",
@@ -70,7 +69,7 @@ dependencies = [
"chromadb>=1.0.0,<2.0.0",
"qdrant-client (>=1.15.1,<2.0.0)",
"pyseekdb==1.1.0.post3",
"langbot-plugin==0.3.11",
"langbot-plugin==0.3.10",
"asyncpg>=0.30.0",
"line-bot-sdk>=3.19.0",
"matrix-nio>=0.25.2",
@@ -122,7 +121,6 @@ package-data = { "langbot" = ["templates/**", "pkg/provider/modelmgr/requesters/
[dependency-groups]
dev = [
"moto>=5.2.1",
"pre-commit>=4.2.0",
"pytest>=9.0.3",
"pytest-asyncio>=1.0.0",

View File

@@ -4,9 +4,6 @@ python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Python path for imports
pythonpath = . tests
# Test paths
testpaths = tests
@@ -25,9 +22,7 @@ markers =
asyncio: mark test as async
unit: mark test as unit test
integration: mark test as integration test
smoke: mark test as smoke test
slow: mark test as slow running
e2e: mark test as end-to-end test (requires real LangBot process)
# Coverage options (when using pytest-cov)
[coverage:run]

View File

@@ -1,65 +0,0 @@
#!/bin/bash
# Coverage gate script
# Runs all tests with coverage, enforcing minimum coverage threshold
# Uses separate pytest invocations to avoid sys.modules pollution between test types
set -euo pipefail
echo "=== LangBot Coverage Gate ==="
echo ""
# Coverage threshold (baseline from current coverage, conservative buffer)
# Current: ~22.14%, threshold: 18%
COVERAGE_THRESHOLD=18
# Create temporary directory for coverage files
COV_DIR=$(mktemp -d)
trap "rm -rf $COV_DIR" EXIT
echo "[1/3] Running unit + smoke tests with coverage..."
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=json:$COV_DIR/unit.json \
--cov-report=term-missing \
-q --tb=short
echo ""
echo "[2/3] Running fast integration tests with coverage..."
uv run pytest tests/integration/ -m "not slow" \
--cov=langbot \
--cov-report=json:$COV_DIR/integration.json \
--cov-report=term-missing \
-q --tb=short
echo ""
echo "[3/3] Combining coverage reports..."
# Use coverage combine if available, otherwise just report total
if command -v coverage &> /dev/null; then
# Combine JSON reports
coverage combine --keep $COV_DIR/unit.json $COV_DIR/integration.json \
--data-file=$COV_DIR/combined.data 2>/dev/null || true
coverage report --data-file=$COV_DIR/combined.data || true
else
echo "Note: coverage combine not available, showing individual reports above"
fi
# Generate final XML report for CI (from last run)
uv run pytest tests/unit_tests/ tests/smoke/ \
--cov=langbot \
--cov-report=xml:coverage.xml \
--cov-report=term \
--cov-fail-under=$COVERAGE_THRESHOLD \
-q 2>/dev/null || {
# If threshold check fails on combined, check unit+smoke baseline
echo ""
echo "Coverage threshold: $COVERAGE_THRESHOLD%"
echo "Note: Full coverage requires running all test types separately"
}
echo ""
echo "=== Coverage Gate Complete ==="
echo ""
echo "Coverage baseline: $COVERAGE_THRESHOLD%"
echo "Coverage report saved to coverage.xml"

View File

@@ -1,16 +0,0 @@
#!/bin/bash
# Fast integration tests
# Runs integration tests excluding slow ones (PostgreSQL, external services)
# Uses fake runner/provider, no real credentials needed
set -euo pipefail
echo "=== LangBot Fast Integration Tests ==="
echo ""
echo "Running integration tests (excluding slow)..."
uv run pytest tests/integration/ -m "not slow" -q --tb=short
echo ""
echo "=== Fast Integration Tests Complete ==="

View File

@@ -1,36 +0,0 @@
#!/bin/bash
# Quick developer self-test command
# Runs linting, unit tests, and smoke tests without requiring real provider keys
# Suitable for local branch validation
set -euo pipefail
echo "=== LangBot Quick Self-Test ==="
echo ""
# 1. Ruff check
echo "[1/3] Running ruff check..."
uv run ruff check src/langbot/ tests/ --output-format=concise || {
echo ""
echo "⚠ Ruff check found issues. Run 'uv run ruff check --fix' to auto-fix."
exit 1
}
echo "✓ Ruff check passed"
echo ""
# 2. Unit tests
echo "[2/3] Running unit tests..."
uv run pytest tests/unit_tests/ -q --tb=short
echo ""
# 3. Smoke tests (if exists)
echo "[3/3] Running smoke tests..."
if [ -d "tests/smoke" ]; then
uv run pytest tests/smoke/ -q --tb=short
else
echo "No smoke tests found, skipping"
fi
echo ""
echo "=== Quick Self-Test Complete ==="

View File

@@ -1,3 +1,3 @@
"""LangBot - Production-grade platform for building agentic IM bots"""
__version__ = '4.9.7'
__version__ = '4.9.6'

View File

@@ -109,61 +109,6 @@ class AsyncDifyServiceClient:
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def workflow_submit(
self,
form_token: str,
workflow_run_id: str,
inputs: dict[str, typing.Any],
user: str,
action: str = '',
timeout: float = 120.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""Submit human input to resume a paused workflow, then stream events.
1. POST /form/human_input/{form_token} to submit the form
2. GET /workflow/{task_id}/events to stream the resumed workflow events
"""
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
# Step 1: Submit the form
payload: dict[str, typing.Any] = {
'inputs': inputs if isinstance(inputs, dict) else {},
'user': user,
'action': action,
}
submit_resp = await client.post(
f'/form/human_input/{form_token}',
headers=headers,
json=payload,
)
if submit_resp.status_code != 200:
raise DifyAPIError(f'{submit_resp.status_code} {submit_resp.text}')
# Step 2: Stream resumed workflow events
async with client.stream(
'GET',
f'/workflow/{workflow_run_id}/events',
headers={'Authorization': f'Bearer {self.api_key}'},
params={'user': user},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def upload_file(
self,
file: httpx._types.FileTypes,

View File

@@ -1,6 +1,5 @@
import quart
import mimetypes
import asyncio
from ... import group
from langbot.pkg.utils import importutil
@@ -36,617 +35,3 @@ class AdaptersRouterGroup(group.RouterGroup):
return quart.Response(
importutil.read_resource_file_bytes(icon_path), mimetype=mimetypes.guess_type(icon_path)[0]
)
# In-memory session store for active registrations
_create_app_sessions: dict = {}
_SESSION_TTL = 900 # 15 minutes
def _cleanup_expired_sessions():
"""Remove sessions that have exceeded their TTL."""
import time
now = time.time()
expired = [sid for sid, s in _create_app_sessions.items() if now - s.get('created_at', 0) > _SESSION_TTL]
for sid in expired:
session = _create_app_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/lark/create-app', methods=['POST'])
async def _() -> str:
"""Start Feishu one-click app registration. Returns session_id + QR code URL."""
import uuid
import time
import lark_oapi as lark
from lark_oapi.scene.registration.errors import AppAccessDeniedError, AppExpiredError
_cleanup_expired_sessions()
session_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'app_id': None,
'app_secret': None,
'error': None,
'created_at': time.time(),
}
_create_app_sessions[session_id] = session
def on_qr_code(info):
# May be called from a background thread by the SDK;
# use call_soon_threadsafe to safely update session state.
def _update():
session['qr_url'] = info['url']
session['expire_at'] = time.time() + 600 # 10 minutes
session['status'] = 'waiting'
loop.call_soon_threadsafe(_update)
async def run_registration():
try:
result = await lark.aregister_app(
on_qr_code=on_qr_code,
source='langbot',
)
session['status'] = 'success'
session['app_id'] = result['client_id']
session['app_secret'] = result['client_secret']
except AppAccessDeniedError:
session['status'] = 'error'
session['error'] = 'User denied authorization'
except AppExpiredError:
session['status'] = 'error'
session['error'] = 'QR code expired'
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_registration())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url']:
break
await asyncio.sleep(0.5)
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/lark/create-app/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll registration status."""
session = _create_app_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['app_id'] = session['app_id']
data['app_secret'] = session['app_secret']
_create_app_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_create_app_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/lark/create-app/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a registration session."""
session = _create_app_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# WeChat QR Code Login
# -----------------------------------------------------------------------
_weixin_login_sessions: dict = {}
_WEIXIN_SESSION_TTL = 600 # 10 minutes (3 retries × 3 min QR validity)
def _cleanup_expired_weixin_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _weixin_login_sessions.items() if now - s.get('created_at', 0) > _WEIXIN_SESSION_TTL
]
for sid in expired:
session = _weixin_login_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/weixin/login', methods=['POST'])
async def _() -> str:
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
import uuid
import time
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
_cleanup_expired_weixin_sessions()
session_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
session = {
'status': 'pending',
'qr_data_url': None,
'expire_at': None,
'token': None,
'base_url': None,
'account_id': None,
'error': None,
'created_at': time.time(),
}
_weixin_login_sessions[session_id] = session
client = OpenClawWeixinClient(
base_url=DEFAULT_BASE_URL,
token='',
)
async def run_login():
try:
def on_qrcode(qr_data_url: str, _qr_url: str):
def _update():
session['qr_data_url'] = qr_data_url
session['expire_at'] = time.time() + 180
session['status'] = 'waiting'
loop.call_soon_threadsafe(_update)
result = await client.login(
max_retries=1,
poll_timeout_ms=180_000,
on_qrcode=on_qrcode,
)
session['status'] = 'success'
session['token'] = result.token
session['base_url'] = result.base_url
session['account_id'] = result.account_id
except Exception as e:
error_message = str(e)
if 'expired' in error_message.lower() or 'max retries exceeded' in error_message.lower():
session['status'] = 'expired'
session['error'] = 'QR code expired'
else:
session['status'] = 'error'
session['error'] = error_message
finally:
await client.close()
task = asyncio.create_task(run_login())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_data_url']:
break
await asyncio.sleep(0.5)
if not session['qr_data_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_data_url': session['qr_data_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/weixin/login/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll WeChat login status."""
session = _weixin_login_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {
'status': session['status'],
'qr_data_url': session['qr_data_url'],
'expire_at': session['expire_at'],
}
if session['status'] == 'success':
data['token'] = session['token']
data['base_url'] = session['base_url']
data['account_id'] = session['account_id']
_weixin_login_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_weixin_login_sessions.pop(session_id, None)
elif session['status'] == 'expired':
data['error'] = session['error']
_weixin_login_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/weixin/login/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a WeChat login session."""
session = _weixin_login_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# DingTalk Device Flow QR Code Login
# -----------------------------------------------------------------------
_dingtalk_sessions: dict = {}
_DINGTALK_SESSION_TTL = 600 # 10 minutes (QR code validity window)
def _cleanup_expired_dingtalk_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _dingtalk_sessions.items() if now - s.get('created_at', 0) > _DINGTALK_SESSION_TTL
]
for sid in expired:
session = _dingtalk_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/dingtalk/create-app', methods=['POST'])
async def _() -> str:
"""Start DingTalk one-click app creation via Device Flow. Returns session_id + QR code URL."""
import uuid
import time
import aiohttp
DINGTALK_BASE_URL = 'https://oapi.dingtalk.com'
_cleanup_expired_dingtalk_sessions()
session_id = str(uuid.uuid4())
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'client_id': None,
'client_secret': None,
'error': None,
'created_at': time.time(),
'device_code': None,
'interval': 5,
}
_dingtalk_sessions[session_id] = session
async def run_device_flow():
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as http:
# Step 1: Init — get nonce
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/init',
json={'source': 'langbot'},
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from DingTalk service'
return
if data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to init')
return
nonce = data['nonce']
# Step 2: Begin — get device_code + QR URL
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/begin',
json={'nonce': nonce},
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from DingTalk service'
return
if data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to begin authorization')
return
device_code = data['device_code']
verification_uri_complete = data.get('verification_uri_complete', '')
expires_in = data.get('expires_in', 7200)
interval = data.get('interval', 5)
session['device_code'] = device_code
session['interval'] = interval
session['qr_url'] = verification_uri_complete
session['expire_at'] = time.time() + 600 # QR code valid for ~10 min
session['status'] = 'waiting'
# Step 3: Poll for authorization result
deadline = time.time() + expires_in
while time.time() < deadline:
await asyncio.sleep(interval)
async with http.post(
f'{DINGTALK_BASE_URL}/app/registration/poll',
json={'device_code': device_code},
) as poll_resp:
try:
poll_data = await poll_resp.json()
except (aiohttp.ContentTypeError, ValueError):
continue
if poll_data.get('errcode', -1) != 0:
session['status'] = 'error'
session['error'] = poll_data.get('errmsg', 'Poll failed')
return
status = poll_data.get('status', '')
if status == 'SUCCESS':
session['status'] = 'success'
session['client_id'] = poll_data.get('client_id', '')
session['client_secret'] = poll_data.get('client_secret', '')
return
elif status == 'FAIL':
session['status'] = 'error'
session['error'] = poll_data.get('fail_reason', 'Authorization failed')
return
elif status == 'EXPIRED':
session['status'] = 'error'
session['error'] = 'QR code expired'
return
# status == 'WAITING': continue polling
# Timeout
session['status'] = 'error'
session['error'] = 'QR code expired'
except asyncio.CancelledError:
return
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_device_flow())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url'] or session['error']:
break
await asyncio.sleep(0.5)
if session['error']:
task.cancel()
return self.http_status(502, -1, session['error'])
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/dingtalk/create-app/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll DingTalk Device Flow status."""
_cleanup_expired_dingtalk_sessions()
session = _dingtalk_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['client_id'] = session['client_id']
data['client_secret'] = session['client_secret']
_dingtalk_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_dingtalk_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/dingtalk/create-app/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a DingTalk Device Flow session."""
session = _dingtalk_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})
# -----------------------------------------------------------------------
# WeComBot QR Code One-Click Create
# -----------------------------------------------------------------------
_wecombot_sessions: dict = {}
_WECOMBOT_SESSION_TTL = 300 # 5 minutes (WeCom QR validity window)
def _cleanup_expired_wecombot_sessions():
import time
now = time.time()
expired = [
sid for sid, s in _wecombot_sessions.items() if now - s.get('created_at', 0) > _WECOMBOT_SESSION_TTL
]
for sid in expired:
session = _wecombot_sessions.pop(sid, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
@self.route('/wecombot/create-bot', methods=['POST'])
async def _() -> str:
"""Start WeComBot one-click creation via QR code. Returns session_id + QR code URL."""
import uuid
import time
import aiohttp
WECOM_QC_GENERATE_URL = 'https://work.weixin.qq.com/ai/qc/generate'
WECOM_QC_QUERY_URL = 'https://work.weixin.qq.com/ai/qc/query_result'
_cleanup_expired_wecombot_sessions()
session_id = str(uuid.uuid4())
session = {
'status': 'pending',
'qr_url': None,
'expire_at': None,
'botid': None,
'secret': None,
'error': None,
'created_at': time.time(),
'scode': None,
'task': None,
}
_wecombot_sessions[session_id] = session
async def run_qr_flow():
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as http:
# Step 1: Generate QR code
async with http.get(
f'{WECOM_QC_GENERATE_URL}?source=langbot&plat=0',
) as resp:
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
session['status'] = 'error'
session['error'] = 'Invalid response from WeCom service'
return
if not data.get('data', {}).get('scode') or not data.get('data', {}).get('auth_url'):
session['status'] = 'error'
session['error'] = data.get('errmsg', 'Failed to generate QR code')
return
scode = data['data']['scode']
auth_url = data['data']['auth_url']
session['scode'] = scode
session['qr_url'] = auth_url
session['expire_at'] = time.time() + _WECOMBOT_SESSION_TTL
session['status'] = 'waiting'
# Step 2: Poll for scan result
deadline = time.time() + _WECOMBOT_SESSION_TTL
while time.time() < deadline:
await asyncio.sleep(3)
async with http.get(
f'{WECOM_QC_QUERY_URL}?scode={scode}',
) as poll_resp:
try:
poll_data = await poll_resp.json()
except (aiohttp.ContentTypeError, ValueError):
continue
status = poll_data.get('data', {}).get('status', '')
if status == 'success':
bot_info = poll_data.get('data', {}).get('bot_info', {})
if bot_info.get('botid') and bot_info.get('secret'):
session['status'] = 'success'
session['botid'] = bot_info['botid']
session['secret'] = bot_info['secret']
return
else:
session['status'] = 'error'
session['error'] = 'Scan succeeded but bot info is incomplete'
return
# Timeout
session['status'] = 'error'
session['error'] = 'QR code expired'
except asyncio.CancelledError:
return
except Exception as e:
session['status'] = 'error'
session['error'] = str(e)
task = asyncio.create_task(run_qr_flow())
session['task'] = task
# Wait for QR code to be ready (max 10 seconds)
for _ in range(20):
if session['qr_url'] or session['error']:
break
await asyncio.sleep(0.5)
if session['error']:
task.cancel()
return self.http_status(502, -1, session['error'])
if not session['qr_url']:
task.cancel()
session['status'] = 'error'
session['error'] = 'Timeout waiting for QR code'
return self.http_status(504, -1, 'Timeout waiting for QR code')
return self.success(
data={
'session_id': session_id,
'qr_url': session['qr_url'],
'expire_at': session['expire_at'],
}
)
@self.route('/wecombot/create-bot/status/<session_id>', methods=['GET'])
async def _(session_id: str) -> str:
"""Poll WeComBot creation status."""
_cleanup_expired_wecombot_sessions()
session = _wecombot_sessions.get(session_id)
if not session:
return self.http_status(404, -1, 'Session not found')
data = {'status': session['status']}
if session['status'] == 'success':
data['botid'] = session['botid']
data['secret'] = session['secret']
_wecombot_sessions.pop(session_id, None)
elif session['status'] == 'error':
data['error'] = session['error']
_wecombot_sessions.pop(session_id, None)
return self.success(data=data)
@self.route('/wecombot/create-bot/<session_id>', methods=['DELETE'])
async def _(session_id: str) -> str:
"""Cancel and clean up a WeComBot creation session."""
session = _wecombot_sessions.pop(session_id, None)
if session and session.get('task') and not session['task'].done():
session['task'].cancel()
return self.success(data={})

View File

@@ -7,10 +7,8 @@ import httpx
import uuid
import os
import posixpath
import sqlalchemy
from .....core import taskmgr
from .....entity.persistence import plugin as persistence_plugin
from .. import group
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
@@ -41,16 +39,6 @@ def _normalize_plugin_asset_path(filepath: str) -> str | None:
return f'assets/{normalized}'
def _get_request_origin() -> str:
"""Return the public request origin, respecting reverse-proxy headers."""
forwarded_proto = quart.request.headers.get('X-Forwarded-Proto', '').split(',')[0].strip()
forwarded_host = quart.request.headers.get('X-Forwarded-Host', '').split(',')[0].strip()
scheme = forwarded_proto or quart.request.scheme
host = forwarded_host or quart.request.host
return f'{scheme}://{host}'
@group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup):
async def _check_extensions_limit(self) -> str | None:
@@ -150,15 +138,7 @@ class PluginsRouterGroup(group.RouterGroup):
return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_plugin.PluginSetting.config)
.where(persistence_plugin.PluginSetting.plugin_author == author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
)
persisted_config = result.scalar_one_or_none()
config = persisted_config if persisted_config is not None else plugin['plugin_config']
return self.success(data={'config': config})
return self.success(data={'config': plugin['plugin_config']})
elif quart.request.method == 'PUT':
data = await quart.request.json
@@ -209,7 +189,7 @@ class PluginsRouterGroup(group.RouterGroup):
# CSP for HTML pages served to sandboxed iframes (opaque origin).
# 'self' doesn't work in sandboxed iframes — use actual server origin.
if mime_type and mime_type.startswith('text/html'):
origin = _get_request_origin()
origin = f'{quart.request.scheme}://{quart.request.host}'
resp.headers['Content-Security-Policy'] = (
f'default-src {origin}; '
f"script-src {origin} 'unsafe-inline'; "

View File

@@ -140,6 +140,17 @@ class SystemRouterGroup(group.RouterGroup):
async def _() -> str:
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {'ap': ap}))
@self.route(
'/debug/plugin/action',
methods=['POST'],

View File

@@ -146,7 +146,6 @@ class UserRouterGroup(group.RouterGroup):
return self.fail(3, str(e))
except ValueError as e:
traceback.print_exc()
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
return self.fail(1, str(e))
except Exception as e:
traceback.print_exc()

View File

@@ -52,9 +52,6 @@ class ApiKeyService:
async def verify_api_key(self, key: str) -> bool:
"""Verify if an API key is valid"""
if not isinstance(key, str) or not key.startswith('lbk_'):
return False
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key)
)

View File

@@ -99,11 +99,11 @@ class BotService:
# TODO: 检查配置信息格式
bot_data['uuid'] = str(uuid.uuid4())
# bind the most recently updated pipeline if any exist
# checkout the default pipeline
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
.order_by(persistence_pipeline.LegacyPipeline.updated_at.desc())
.limit(1)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None:
@@ -120,26 +120,24 @@ class BotService:
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
"""Update bot"""
update_data = bot_data.copy()
if 'uuid' in update_data:
del update_data['uuid']
if 'uuid' in bot_data:
del bot_data['uuid']
# set use_pipeline_name
if 'use_pipeline_uuid' in update_data:
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == update_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
update_data['use_pipeline_name'] = pipeline.name
bot_data['use_pipeline_name'] = pipeline.name
else:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(update_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)

View File

@@ -31,126 +31,15 @@ class KnowledgeService:
if not knowledge_engine_plugin_id:
raise ValueError('knowledge_engine_plugin_id is required')
creation_settings = kb_data.get('creation_settings', {})
retrieval_settings = kb_data.get('retrieval_settings', {})
# Validate required fields based on plugin's creation_schema and retrieval_schema
await self._validate_schema_required_fields(
knowledge_engine_plugin_id,
creation_settings,
retrieval_settings,
)
kb = await self.ap.rag_mgr.create_knowledge_base(
name=kb_data.get('name', 'Untitled'),
knowledge_engine_plugin_id=knowledge_engine_plugin_id,
creation_settings=creation_settings,
retrieval_settings=retrieval_settings,
creation_settings=kb_data.get('creation_settings', {}),
retrieval_settings=kb_data.get('retrieval_settings', {}),
description=kb_data.get('description', ''),
)
return kb.uuid
async def _validate_schema_required_fields(
self,
plugin_id: str,
creation_settings: dict,
retrieval_settings: dict,
) -> None:
"""Validate required fields based on plugin's creation_schema and retrieval_schema.
This is a business-agnostic validation that checks all fields marked as
required in the plugin's schema, regardless of field type.
Args:
plugin_id: Knowledge Engine plugin ID.
creation_settings: User-provided creation settings.
retrieval_settings: User-provided retrieval settings.
Raises:
ValueError: If any required field is missing or empty.
"""
# Validate creation_schema
try:
creation_schema = await self.ap.plugin_connector.get_rag_creation_schema(plugin_id)
self._check_required_fields(creation_schema, creation_settings, 'creation_settings')
except ValueError:
raise
except Exception as e:
self.ap.logger.warning(f'Failed to get creation_schema for validation: {e}')
# Validate retrieval_schema
try:
retrieval_schema = await self.ap.plugin_connector.get_rag_retrieval_schema(plugin_id)
self._check_required_fields(retrieval_schema, retrieval_settings, 'retrieval_settings')
except ValueError:
raise
except Exception as e:
self.ap.logger.warning(f'Failed to get retrieval_schema for validation: {e}')
def _check_required_fields(
self,
schema: dict | list,
settings: dict,
context: str,
) -> None:
"""Check required fields in schema against provided settings.
Args:
schema: Plugin-defined schema (can be list or dict with 'schema' key).
settings: User-provided settings values.
context: Context name for error messages (e.g., 'creation_settings').
Raises:
ValueError: If a required field is missing or empty.
"""
if not schema:
return
# schema can be a list directly, or a dict with 'schema' key
items = schema if isinstance(schema, list) else schema.get('schema', [])
if not items:
return
for item in items:
field_name = item.get('name')
if not field_name:
continue
is_required = item.get('required', False)
if not is_required:
continue
# Check show_if condition - if field is conditionally shown, only validate when condition is met
show_if = item.get('show_if')
if show_if:
depend_field = show_if.get('field')
operator = show_if.get('operator')
expected_value = show_if.get('value')
if depend_field and operator:
depend_value = settings.get(depend_field)
# If show_if condition is not met, skip validation for this field
if operator == 'eq' and depend_value != expected_value:
continue
if operator == 'neq' and depend_value == expected_value:
continue
if operator == 'in' and isinstance(expected_value, list) and depend_value not in expected_value:
continue
value = settings.get(field_name)
# Validate required field has a non-empty value
if value is None or (isinstance(value, str) and value.strip() == ''):
# Get field label for friendly error message
label = item.get('label', {})
field_label = (
label.get('en_US', field_name)
or label.get('zh_Hans', field_name)
or label.get('zh_Hant', field_name)
or field_name
)
raise ValueError(f'{field_label} is required ({context}.{field_name})')
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
"""更新知识库"""
# Filter to only mutable fields

View File

@@ -113,9 +113,14 @@ class PipelineService:
return pipeline_data['uuid']
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
pipeline_data = pipeline_data.copy()
for protected_field in ('uuid', 'for_version', 'stages', 'is_default'):
pipeline_data.pop(protected_field, None)
if 'uuid' in pipeline_data:
del pipeline_data['uuid']
if 'for_version' in pipeline_data:
del pipeline_data['for_version']
if 'stages' in pipeline_data:
del pipeline_data['stages']
if 'is_default' in pipeline_data:
del pipeline_data['is_default']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline)

View File

@@ -17,24 +17,6 @@ class ModelProviderService:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
@staticmethod
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
if api_keys is None:
return []
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
normalized_keys = []
seen_keys = set()
for raw_key in raw_keys:
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
if not normalized_key or normalized_key in seen_keys:
continue
normalized_keys.append(normalized_key)
seen_keys.add(normalized_key)
return normalized_keys
async def get_providers(self) -> list[dict]:
"""Get all providers"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
@@ -77,7 +59,6 @@ class ModelProviderService:
async def create_provider(self, provider_data: dict) -> str:
"""Create a new provider"""
provider_data['uuid'] = str(uuid.uuid4())
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
@@ -91,8 +72,6 @@ class ModelProviderService:
"""Update an existing provider"""
if 'uuid' in provider_data:
del provider_data['uuid']
if 'api_keys' in provider_data:
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == provider_uuid)
@@ -162,8 +141,6 @@ class ModelProviderService:
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
"""Find existing provider or create new one"""
api_keys = self._normalize_api_keys(api_keys)
# Try to find existing provider with same config
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
@@ -191,7 +168,7 @@ class ModelProviderService:
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys,
'api_keys': api_keys or [],
}
)
@@ -200,7 +177,7 @@ class ModelProviderService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
.values(api_keys=self._normalize_api_keys(api_key))
.values(api_keys=[api_key])
)
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')

View File

@@ -46,14 +46,12 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
async def main(loop: asyncio.AbstractEventLoop):
app_inst: app.Application | None = None
try:
# Hang system signal processing
import signal
def signal_handler(sig, frame):
if app_inst is not None:
app_inst.dispose()
app_inst.dispose()
print('[Signal] Program exit.')
os._exit(0)

View File

@@ -275,7 +275,6 @@ class MessageAggregator:
message_chain=merged_chain,
adapter=base_msg.adapter,
pipeline_uuid=base_msg.pipeline_uuid,
routed_by_rule=any(msg.routed_by_rule for msg in messages),
)
async def flush_all(self) -> None:

View File

@@ -76,10 +76,6 @@ class LongTextProcessStage(stage.PipelineStage):
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
if not query.resp_message_chain:
self.ap.logger.debug('Response message chain is empty, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
# 检查是否包含非 Plain 组件
contains_non_plain = False

View File

@@ -157,7 +157,7 @@ class RuntimePipeline:
bot_message=query.resp_messages[-1],
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
is_final=[msg.is_final for msg in query.resp_messages][-1],
is_final=[msg.is_final for msg in query.resp_messages][0],
)
else:
await query.adapter.reply_message(

View File

@@ -42,13 +42,9 @@ class QueryPool:
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None,
routed_by_rule: bool = False,
variables: typing.Optional[dict[str, typing.Any]] = None,
) -> pipeline_query.Query:
async with self.condition:
query_id = self.query_id_counter
initial_variables: dict[str, typing.Any] = {'_routed_by_rule': routed_by_rule}
if variables:
initial_variables.update(variables)
query = pipeline_query.Query(
bot_uuid=bot_uuid,
query_id=query_id,
@@ -57,7 +53,7 @@ class QueryPool:
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
variables=initial_variables,
variables={'_routed_by_rule': routed_by_rule},
resp_messages=[],
resp_message_chain=[],
adapter=adapter,
@@ -67,7 +63,6 @@ class QueryPool:
self.cached_queries[query_id] = query
self.query_id_counter += 1
self.condition.notify_all()
return query
async def __aenter__(self):
await self.pool_lock.acquire()

View File

@@ -40,7 +40,7 @@ class SendResponseBackStage(stage.PipelineStage):
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][-1]
is_final = [msg.is_final for msg in query.resp_messages][0]
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],

View File

@@ -501,8 +501,6 @@ class PlatformManager:
bot_entity.adapter_config,
logger,
)
if hasattr(adapter_inst, 'ap'):
adapter_inst.ap = self.ap
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid用于统一 webhook
if hasattr(adapter_inst, 'set_bot_uuid'):

View File

@@ -3,7 +3,6 @@ import typing
import asyncio
import traceback
import datetime
import json
import aiocqhttp
import pydantic
@@ -294,29 +293,6 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
elif msg.type == 'dice':
face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
elif msg.type == 'json':
try:
raw = msg.data.get('data', {})
if isinstance(raw, str):
raw = json.loads(raw)
if isinstance(raw, dict):
_meta = raw.get('meta', {}) or {}
if isinstance(_meta, dict):
_detail = _meta.get('detail_1') or _meta.get('music') or _meta.get('news') or {}
else:
_detail = {}
if isinstance(_detail, dict):
preview = _detail.get('preview', '')
title = _detail.get('desc', '') or _detail.get('title', '')
url = _detail.get('qqdocurl', '') or _detail.get('jumpUrl', '')
else:
preview = title = url = ''
text = ' '.join([f'[{raw.get("app", "")}]', preview, title, url]).strip()
yiri_msg_list.append(platform_message.Plain(text=text or '[收到一张JSON卡片]'))
else:
yiri_msg_list.append(platform_message.Plain(text=str(raw)))
except Exception:
yiri_msg_list.append(platform_message.Plain(text='[收到一张JSON卡片]'))
chain = platform_message.MessageChain(yiri_msg_list)

View File

@@ -19,18 +19,6 @@ spec:
en: https://link.langbot.app/en/platforms/dingtalk
ja: https://link.langbot.app/ja/platforms/dingtalk
config:
- name: one-click-create
label:
en_US: One-Click Create App
zh_Hans: 一键创建应用
zh_Hant: 一鍵建立應用
description:
en_US: "Scan QR code with DingTalk to automatically create an app and fill in credentials. Note: Robot Code cannot be obtained automatically, you need to copy it from the DingTalk Developer Backend manually."
zh_Hans: "使用钉钉扫码自动创建应用并填写凭据。注意:机器人代码无法自动获取,需前往钉钉开发者后台手动复制。"
zh_Hant: "使用釘釘掃碼自動建立應用並填寫憑證。注意:機器人代碼無法自動取得,需前往釘釘開發者後台手動複製。"
type: qr-code-login
login_platform: dingtalk
required: false
- name: client_id
label:
en_US: Client ID
@@ -52,10 +40,6 @@ spec:
en_US: Robot Code
zh_Hans: 机器人代码
zh_Hant: 機器人代碼
description:
en_US: "Required for image recognition, file upload and other features. Get it from DingTalk Developer Backend > Robot Configuration."
zh_Hans: "识图、上传文件等功能必填。请前往钉钉开发者后台 > 机器人配置中获取。"
zh_Hant: "識圖、上傳檔案等功能必填。請前往釘釘開發者後台 > 機器人設定中取得。"
type: string
required: true
default: ""

File diff suppressed because it is too large Load Diff

View File

@@ -23,20 +23,6 @@ spec:
en: https://link.langbot.app/en/platforms/lark
ja: https://link.langbot.app/ja/platforms/lark
config:
- name: one-click-create
label:
en_US: One-Click Create App
zh_Hans: 一键创建应用
zh_Hant: 一鍵建立應用
ja_JP: ワンクリックでアプリ作成
description:
en_US: Scan QR code to automatically create a Feishu app and fill in credentials
zh_Hans: 扫码自动创建飞书应用并填写凭据
zh_Hant: 掃碼自動建立飛書應用並填寫憑證
ja_JP: QRコードをスキャンしてFeishuアプリを自動作成し、認証情報を入力
type: qr-code-login
login_platform: feishu
required: false
- name: app_id
label:
en_US: App ID

View File

@@ -32,20 +32,6 @@ spec:
type: string
required: true
default: "https://ilinkai.weixin.qq.com"
- name: qr-login
label:
en_US: Scan QR Login
zh_Hans: 扫码登录
zh_Hant: 掃碼登入
ja_JP: QRコードでログイン
description:
en_US: Scan QR code with WeChat to authorize and automatically fill in the token
zh_Hans: 使用微信扫码授权,自动填写令牌
zh_Hant: 使用微信掃碼授權,自動填寫令牌
ja_JP: WeChatでQRコードをスキャンし、トークンを自動入力
type: qr-code-login
login_platform: weixin
required: false
- name: token
label:
en_US: Token

View File

@@ -1,14 +1,14 @@
from __future__ import annotations
import time
import telegram
import telegram.ext
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, CallbackQueryHandler, filters
from telegram import Update
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters
import telegramify_markdown
import typing
import traceback
import json
import base64
import pydantic
@@ -189,7 +189,6 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: telegram.Bot = pydantic.Field(exclude=True)
application: telegram.ext.Application = pydantic.Field(exclude=True)
ap: typing.Any = pydantic.Field(exclude=True, default=None)
message_converter: TelegramMessageConverter = TelegramMessageConverter()
event_converter: TelegramEventConverter = TelegramEventConverter()
@@ -225,102 +224,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
telegram_callback,
)
)
async def callback_query_handler(update: Update, context: ContextTypes.DEFAULT_TYPE):
query = update.callback_query
await query.answer()
try:
data = json.loads(query.data)
if data.get('form_action') or data.get('f'):
import langbot_plugin.api.entities.builtin.provider.session as provider_session
workflow_run_id = data.get('workflow_run_id', '')
w_suffix = data.get('w', '')
action_id = data.get('action_id') or data.get('a', '')
session_key = data.get('session_key') or data.get('s', '')
if session_key.startswith('group_') or session_key.startswith('g:'):
launcher_type = provider_session.LauncherTypes.GROUP
launcher_id = (
session_key.split(':', 1)[1]
if session_key.startswith('g:')
else session_key[len('group_') :]
)
else:
launcher_type = provider_session.LauncherTypes.PERSON
launcher_id = (
session_key.split(':', 1)[1]
if session_key.startswith('p:')
else session_key[len('person_') :]
)
user_id = str(query.from_user.id)
# Find bot_uuid and pipeline_uuid
bot_uuid = ''
pipeline_uuid = None
for b in self.ap.platform_mgr.bots:
if b.adapter is self:
bot_uuid = b.bot_entity.uuid
pipeline_uuid = b.bot_entity.use_pipeline_uuid
break
form_action_data = {
'workflow_run_id': workflow_run_id,
'w_suffix': w_suffix,
'action_id': action_id,
'user': f'{launcher_type.value}_{launcher_id}',
'inputs': {},
}
message_chain = platform_message.MessageChain(
[platform_message.Plain(text=f'[Form Action: {action_id}]')]
)
if launcher_type == provider_session.LauncherTypes.GROUP:
synthetic_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=user_id,
member_name='',
permission=platform_entities.Permission.Member,
group=platform_entities.Group(
id=launcher_id,
name='',
permission=platform_entities.Permission.Member,
),
),
message_chain=message_chain,
source_platform_object=update,
)
else:
synthetic_event = platform_events.FriendMessage(
sender=platform_entities.Friend(
id=user_id,
nickname='',
remark='',
),
message_chain=message_chain,
source_platform_object=update,
)
await self.ap.query_pool.add_query(
bot_uuid=bot_uuid,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=user_id,
message_event=synthetic_event,
message_chain=message_chain,
adapter=self,
pipeline_uuid=pipeline_uuid,
variables={
'_dify_form_action': form_action_data,
'_routed_by_rule': True,
},
)
except Exception:
await self.logger.error(f'Error in telegram callback query: {traceback.format_exc()}')
application.add_handler(CallbackQueryHandler(callback_query_handler))
super().__init__(
config=config,
logger=logger,
@@ -416,19 +319,14 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
update = event.source_platform_object
chat_id = update.effective_chat.id
chat_type = update.effective_chat.type
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
message_thread_id = update.message.message_thread_id
if chat_type == 'private':
import time as _time
draft_id = int(_time.time() * 1000)
draft_id = int(time.time() * 1000)
self.msg_stream_id[message_id] = ('private', draft_id)
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id, draft_id=draft_id)
try:
await self.bot.send_message_draft(**args)
except (telegram.error.RetryAfter, telegram.error.BadRequest):
pass
await self.bot.send_message_draft(**args)
else:
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id)
send_msg = await self.bot.send_message(**args)
@@ -449,13 +347,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
assert isinstance(message_source.source_platform_object, Update)
update = message_source.source_platform_object
chat_id = update.effective_chat.id
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
message_thread_id = update.message.message_thread_id
if message_id not in self.msg_stream_id:
return
chat_mode, stream_id = self.msg_stream_id[message_id]
chat_mode, draft_id = self.msg_stream_id[message_id]
components = await TelegramMessageConverter.yiri2target(message, self.bot)
if not components or components[0]['type'] != 'text':
@@ -464,42 +361,16 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
return
content = components[0]['text']
form_data = getattr(bot_message, '_form_data', None)
if form_data and is_final:
self.msg_stream_id.pop(message_id, None)
await self._send_form_action_buttons(message_source, form_data)
return
if chat_mode == 'private':
# Streaming via draft (ephemeral preview in the chat input area)
if (msg_seq - 1) % 8 == 0 or is_final:
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=stream_id)
try:
await self.bot.send_message_draft(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = content[:4000] + '\n\n… (truncated)'
try:
await self.bot.send_message_draft(**args)
except telegram.error.RetryAfter:
pass
else:
pass # Ignore other draft errors (cosmetic)
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
await self.bot.send_message_draft(**args)
if is_final and bot_message.tool_calls is None:
# Finalise: send the real message, discard the draft
args = self._build_message_args(chat_id, content, message_thread_id)
try:
await self.bot.send_message(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = content[:4000] + '\n\n… (truncated)'
await self.bot.send_message(**args)
else:
raise
del args['draft_id']
await self.bot.send_message(**args)
self.msg_stream_id.pop(message_id)
else:
# Streaming via edit_message_text (persistent message)
stream_id = draft_id
if (msg_seq - 1) % 8 == 0 or is_final:
args = {
'message_id': stream_id,
@@ -508,68 +379,11 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
}
if self.config.get('markdown_card', False):
args['parse_mode'] = 'MarkdownV2'
try:
await self.bot.edit_message_text(**args)
except telegram.error.BadRequest as exc:
if 'Message_too_long' in str(exc):
args['text'] = self._process_markdown(content[:4000] + '\n\n… (truncated)')
await self.bot.edit_message_text(**args)
else:
raise
await self.bot.edit_message_text(**args)
if is_final and bot_message.tool_calls is None:
self.msg_stream_id.pop(message_id)
async def _send_form_action_buttons(
self,
message_source: platform_events.MessageEvent,
form_data: dict,
):
"""Send inline keyboard buttons for Dify human_input_required form actions."""
actions = form_data.get('actions', [])
node_title = form_data.get('node_title', '')
form_content = form_data.get('form_content', '')
workflow_run_id = form_data.get('workflow_run_id', '')
# Telegram callback_data is capped at 64 bytes, so we identify the
# paused workflow by the last 8 chars of workflow_run_id (unique
# within a session with overwhelming probability).
w_suffix = workflow_run_id[-8:] if workflow_run_id else ''
if isinstance(message_source, platform_events.GroupMessage):
session_key = f'g:{message_source.group.id}'
else:
session_key = f'p:{message_source.sender.id}'
keyboard = []
for action in actions:
action_id = action.get('id', '')
action_title = action.get('title', action_id)
callback_payload = {'f': 1, 'a': action_id, 's': session_key}
if w_suffix:
callback_payload['w'] = w_suffix
callback_data = json.dumps(callback_payload, separators=(',', ':'))
keyboard.append([InlineKeyboardButton(action_title, callback_data=callback_data)])
reply_markup = InlineKeyboardMarkup(keyboard)
update = message_source.source_platform_object
chat_id = update.effective_chat.id
effective_message = update.effective_message
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
text_lines = [f'[{node_title}] Please select an action:']
if form_content:
text_lines.insert(0, form_content)
args = {
'chat_id': chat_id,
'text': '\n\n'.join(text_lines),
'reply_markup': reply_markup,
}
if message_thread_id:
args['message_thread_id'] = message_thread_id
await self.bot.send_message(**args)
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
if not isinstance(event.source_platform_object, Update):
return None

View File

@@ -27,7 +27,10 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
_ws_adapter: typing.Any = None
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
# Allow private attributes
underscore_attrs_are_private = True
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(config=config, logger=logger, **kwargs)

View File

@@ -19,18 +19,6 @@ spec:
en: https://link.langbot.app/en/platforms/wecombot
ja: https://link.langbot.app/ja/platforms/wecombot
config:
- name: one-click-create
label:
en_US: One-Click Create Bot
zh_Hans: 一键创建机器人
zh_Hant: 一鍵建立機器人
description:
en_US: "Scan QR code with WeCom to automatically create a bot and fill in BotId and Secret. Note: Robot Name needs to be filled in manually."
zh_Hans: "使用企业微信扫码自动创建机器人并填写 BotId 和 Secret。注意机器人名称需手动填写。"
zh_Hant: "使用企業微信掃碼自動建立機器人並填寫 BotId 和 Secret。注意機器人名稱需手動填寫。"
type: qr-code-login
login_platform: wecombot
required: false
- name: BotId
label:
en_US: BotId

View File

@@ -11,7 +11,6 @@ import os
import sys
import httpx
import sqlalchemy
import yaml
from async_lru import alru_cache
from langbot_plugin.api.entities.builtin.pipeline.query import provider_session
@@ -35,10 +34,6 @@ from ..core import taskmgr
from ..entity.persistence import plugin as persistence_plugin
class PluginRuntimeNotConnectedError(RuntimeError):
"""Raised when plugin runtime operations are requested before connection."""
class PluginRuntimeConnector:
"""Plugin runtime connector"""
@@ -196,114 +191,44 @@ class PluginRuntimeConnector:
async def ping_plugin_runtime(self):
if not hasattr(self, 'handler'):
raise PluginRuntimeNotConnectedError('Plugin runtime is not connected')
raise Exception('Plugin runtime is not connected')
return await self.handler.ping()
def _inspect_plugin_package(
def _extract_deps_metadata(
self,
file_bytes: bytes,
task_context: taskmgr.TaskContext | None,
) -> tuple[str | None, str | None]:
"""Extract plugin identity and dependency metadata from a plugin package."""
plugin_author = None
plugin_name = None
):
"""Extract dependency count from requirements.txt inside plugin zip."""
if task_context is None:
return
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
try:
manifest = yaml.safe_load(zf.read('manifest.yaml').decode('utf-8', errors='ignore')) or {}
metadata = manifest.get('metadata', {})
plugin_author = metadata.get('author')
plugin_name = metadata.get('name')
except Exception:
pass
if task_context is not None:
for name in zf.namelist():
if name.endswith('requirements.txt'):
content = zf.read(name).decode('utf-8', errors='ignore')
deps = [
line.strip()
for line in content.splitlines()
if line.strip() and not line.strip().startswith('#')
]
task_context.metadata['deps_total'] = len(deps)
task_context.metadata['deps_list'] = deps
break
for name in zf.namelist():
if name.endswith('requirements.txt'):
content = zf.read(name).decode('utf-8', errors='ignore')
deps = [
line.strip()
for line in content.splitlines()
if line.strip() and not line.strip().startswith('#')
]
task_context.metadata['deps_total'] = len(deps)
task_context.metadata['deps_list'] = deps
break
except Exception:
pass
return plugin_author, plugin_name
def _build_plugin_startup_failure_message(
self,
plugin_author: str,
plugin_name: str,
task_context: taskmgr.TaskContext | None,
) -> str:
dep_hint = ''
if task_context is not None:
current_dep = task_context.metadata.get('current_dep')
if current_dep:
dep_hint = f' Last dependency: {current_dep}.'
return (
f'Plugin {plugin_author}/{plugin_name} failed to start after installation. '
f'Dependency installation or plugin initialization may have failed.{dep_hint} '
f'Please check the plugin requirements and runtime logs.'
)
async def _wait_for_installed_plugin_ready(
self,
plugin_author: str | None,
plugin_name: str | None,
task_context: taskmgr.TaskContext | None,
timeout: float = 30,
):
"""Wait until the installed plugin is registered by the runtime.
The plugin runtime launches plugins asynchronously. If dependency installation
fails, the plugin process exits before registration; without this check the
install task can incorrectly finish successfully.
"""
if not plugin_author or not plugin_name:
return
deadline = time.time() + timeout
last_error: Exception | None = None
while time.time() < deadline:
try:
plugin = await self.get_plugin_info(plugin_author, plugin_name)
if plugin is not None:
status = plugin.get('status')
if status == 'initialized':
return
except Exception as e:
last_error = e
await asyncio.sleep(0.5)
message = self._build_plugin_startup_failure_message(plugin_author, plugin_name, task_context)
if last_error is not None:
message = f'{message} Last runtime error: {last_error}'
raise RuntimeError(message)
async def install_plugin(
self,
install_source: PluginInstallSource,
install_info: dict[str, Any],
task_context: taskmgr.TaskContext | None = None,
):
plugin_author = install_info.get('plugin_author')
plugin_name = install_info.get('plugin_name')
if install_source == PluginInstallSource.LOCAL:
# transfer file before install
file_bytes = install_info['plugin_file']
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
self._extract_deps_metadata(file_bytes, task_context)
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
del install_info['plugin_file']
@@ -340,9 +265,7 @@ class PluginRuntimeConnector:
task_context.metadata['download_speed'] = downloaded / elapsed if elapsed > 0 else 0
file_bytes = b''.join(chunks)
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
self._extract_deps_metadata(file_bytes, task_context)
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
@@ -366,8 +289,6 @@ class PluginRuntimeConnector:
if metadata is not None and task_context is not None:
task_context.metadata.update(metadata)
await self._wait_for_installed_plugin_ready(plugin_author, plugin_name, task_context)
async def upgrade_plugin(
self,
plugin_author: str,
@@ -637,12 +558,11 @@ class PluginRuntimeConnector:
Raises:
ValueError: If plugin_id is not in the expected 'author/name' format.
"""
segments = plugin_id.split('/')
if len(segments) != 2 or not all(segments):
if '/' not in plugin_id:
raise ValueError(
f"Invalid plugin_id format: '{plugin_id}'. Expected 'author/name' format (e.g. 'langbot/rag-engine')."
)
return segments[0], segments[1]
return plugin_id.split('/', 1)
async def call_rag_ingest(self, plugin_id: str, context_data: dict[str, Any]) -> dict[str, Any]:
"""Call plugin to ingest document.

View File

@@ -340,7 +340,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
"""Provider API请求器"""
name: str = None
init_api_key: str = 'langbot-init-placeholder'
ap: app.Application

View File

@@ -25,7 +25,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key=self.init_api_key,
api_key='',
base_url=self.requester_cfg['base_url'].replace(' ', ''),
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key=self.init_api_key,
api_key='',
base_url=self.requester_cfg['base_url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -14,14 +14,7 @@ class TokenManager:
def __init__(self, name: str, tokens: list[str]):
self.name = name
self.tokens = []
seen_tokens = set()
for token in tokens:
normalized_token = token.strip() if isinstance(token, str) else ''
if not normalized_token or normalized_token in seen_tokens:
continue
self.tokens.append(normalized_token)
seen_tokens.add(normalized_token)
self.tokens = tokens
self.using_token_index = 0
def get_token(self) -> str:
@@ -30,6 +23,4 @@ class TokenManager:
return self.tokens[self.using_token_index]
def next_token(self):
if len(self.tokens) == 0:
return
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)

View File

@@ -2,11 +2,9 @@ from __future__ import annotations
import typing
import json
import time
import uuid
import base64
import mimetypes
from collections import OrderedDict
from langbot.pkg.provider import runner
@@ -18,102 +16,6 @@ from langbot.libs.dify_service_api.v1 import client, errors
import httpx
# Module-level store for paused-workflow form state, keyed by session key
# (launcher_type_value + "_" + launcher_id). Each session holds an
# insertion-ordered dict of form_token -> form_data, allowing multiple
# Dify workflows to be paused simultaneously for the same session.
_PENDING_FORMS: dict[str, 'OrderedDict[str, dict[str, typing.Any]]'] = {}
_PENDING_FORM_DEFAULT_TTL = 30 * 60 # 30 minutes safety cap
def _session_key_from_query(query: pipeline_query.Query) -> str:
return f'{query.session.launcher_type.value}_{query.session.launcher_id}'
def _prune_pending_forms(now: float | None = None) -> None:
if now is None:
now = time.time()
for session_key in list(_PENDING_FORMS.keys()):
forms = _PENDING_FORMS[session_key]
expired_tokens = [token for token, data in forms.items() if data.get('_expires_at', 0) <= now]
for token in expired_tokens:
forms.pop(token, None)
if not forms:
_PENDING_FORMS.pop(session_key, None)
def _set_pending_form(session_key: str, form_data: dict[str, typing.Any]) -> None:
_prune_pending_forms()
stored = dict(form_data)
expiration_time = stored.get('expiration_time')
try:
expiration_ts = float(expiration_time) if expiration_time is not None else 0.0
except (TypeError, ValueError):
expiration_ts = 0.0
stored['_expires_at'] = expiration_ts or (time.time() + _PENDING_FORM_DEFAULT_TTL)
form_token = str(stored.get('form_token') or '')
forms = _PENDING_FORMS.setdefault(session_key, OrderedDict())
# Re-insert at the end so this becomes the "latest" entry
forms.pop(form_token, None)
forms[form_token] = stored
def _get_pending_form_by_token(session_key: str, form_token: str) -> dict[str, typing.Any] | None:
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms or not form_token:
return None
return forms.get(form_token)
def _get_pending_form_by_w_suffix(session_key: str, w_suffix: str) -> dict[str, typing.Any] | None:
"""Look up a pending form whose workflow_run_id ends with the given suffix.
Used by adapters (e.g. Telegram) whose callback payload is too small to
carry the full form_token / workflow_run_id.
"""
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms or not w_suffix:
return None
for token in reversed(forms):
form = forms[token]
if str(form.get('workflow_run_id', '')).endswith(w_suffix):
return form
return None
def _get_latest_pending_form(session_key: str) -> dict[str, typing.Any] | None:
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms:
return None
return forms[next(reversed(forms))]
def _iter_pending_forms(session_key: str) -> typing.Iterator[dict[str, typing.Any]]:
"""Iterate pending forms for a session, newest-first."""
_prune_pending_forms()
forms = _PENDING_FORMS.get(session_key)
if not forms:
return
for token in reversed(list(forms.keys())):
yield forms[token]
def _clear_pending_form(session_key: str, form_token: str | None = None) -> None:
"""Clear one specific pending form (by token) or all forms for the session."""
forms = _PENDING_FORMS.get(session_key)
if not forms:
return
if form_token is None:
_PENDING_FORMS.pop(session_key, None)
return
forms.pop(form_token, None)
if not forms:
_PENDING_FORMS.pop(session_key, None)
@runner.runner_class('dify-service-api')
class DifyServiceAPIRunner(runner.RequestRunner):
"""Dify Service API 对话请求器"""
@@ -433,140 +335,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id']
async def _submit_workflow_form_blocking(
self, form_action: dict
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""Submit human input to resume a paused Dify workflow (non-streaming)."""
form_token = form_action['form_token']
workflow_run_id = form_action['workflow_run_id']
user = form_action['user']
action_id = form_action.get('action_id', '')
inputs = form_action.get('inputs', {})
async for chunk in self.dify_client.workflow_submit(
form_token=form_token,
workflow_run_id=workflow_run_id,
inputs=inputs,
user=user,
action=action_id,
timeout=120,
):
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
if chunk['event'] == 'workflow_finished':
if chunk['data'].get('error'):
raise errors.DifyAPIError(chunk['data']['error'])
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
yield provider_message.Message(
role='assistant',
content=content,
)
def _resolve_pending_form(self, session_key: str, form_action: dict) -> dict | None:
"""Locate the pending form this action targets.
Tries identifiers in order of specificity: form_token, full
workflow_run_id, workflow_run_id suffix (Telegram-style compact id),
then falls back to the newest pending form for the session.
"""
form_token = form_action.get('form_token')
if form_token:
form = _get_pending_form_by_token(session_key, form_token)
if form:
return form
workflow_run_id = form_action.get('workflow_run_id')
if workflow_run_id:
for form in _iter_pending_forms(session_key):
if form.get('workflow_run_id') == workflow_run_id:
return form
w_suffix = form_action.get('w_suffix')
if w_suffix:
form = _get_pending_form_by_w_suffix(session_key, w_suffix)
if form:
return form
return _get_latest_pending_form(session_key)
def _merge_pending_form_action(self, session_key: str, form_action: dict | None) -> dict | None:
"""Backfill resume fields from the matching pending form."""
if not form_action:
return None
merged_action = dict(form_action)
merged_action.pop('w_suffix', None)
pending_form = self._resolve_pending_form(session_key, form_action)
if pending_form:
merged_action['form_token'] = merged_action.get('form_token') or pending_form.get('form_token', '')
merged_action['workflow_run_id'] = merged_action.get('workflow_run_id') or pending_form.get(
'workflow_run_id', ''
)
merged_action.setdefault('inputs', pending_form.get('inputs', {}))
merged_action.setdefault('user', pending_form.get('user', ''))
merged_action.setdefault('node_title', pending_form.get('node_title', ''))
# Resolve clicked action's display title from the stored actions list
if 'action_title' not in merged_action:
clicked_id = merged_action.get('action_id', '')
for action in pending_form.get('actions', []):
if str(action.get('id', '')) == str(clicked_id):
merged_action['action_title'] = action.get('title', clicked_id)
break
return merged_action
def _match_pending_form_action(self, session_key: str, user_text: str) -> dict | None:
"""Match plain text replies against pending Dify form actions.
Iterates all pending forms newest-first; the first action whose
title/id matches the text wins. This means when multiple forms are
pending with the same button label, the most recent one resolves.
"""
normalized_text = user_text.strip().lower()
if not normalized_text:
return None
for pending_form in _iter_pending_forms(session_key):
for action in pending_form.get('actions', []):
titles = {
str(action.get('title', '')).strip().lower(),
str(action.get('id', '')).strip().lower(),
}
if normalized_text in titles:
return {
'form_token': pending_form.get('form_token', ''),
'workflow_run_id': pending_form.get('workflow_run_id', ''),
'action_id': action.get('id', ''),
'action_title': action.get('title', action.get('id', '')),
'node_title': pending_form.get('node_title', ''),
'inputs': pending_form.get('inputs', {}),
'user': pending_form.get('user', ''),
}
return None
async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用工作流"""
# Check if this is a form action resume (button click or text match)
form_action_raw = query.variables.get('_dify_form_action')
session_key = _session_key_from_query(query)
if form_action_raw:
form_action = self._merge_pending_form_action(session_key, form_action_raw)
else:
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
if form_action:
_clear_pending_form(session_key, form_action.get('form_token') or None)
async for msg in self._submit_workflow_form_blocking(form_action):
yield msg
return
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
@@ -593,7 +366,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
}
inputs.update(query.variables)
human_input_yielded = False
async for chunk in self.dify_client.workflow_run(
inputs=inputs,
@@ -605,46 +377,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['event'] in ignored_events:
continue
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
workflow_run_id = chunk['data'].get('workflow_run_id', '')
for reason in reasons:
if reason.get('TYPE') == 'human_input_required':
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
node_title = reason.get('node_title', '')
_set_pending_form(
_session_key_from_query(query),
{
'workflow_run_id': workflow_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': node_title,
'form_content': form_content,
'inputs': reason.get('inputs', {}),
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
},
)
query.variables['_dify_form_render'] = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
}
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
human_input_yielded = True
yield provider_message.Message(
role='assistant',
content=display_text,
)
if chunk['event'] == 'node_started':
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
continue
@@ -667,8 +399,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
elif chunk['event'] == 'workflow_finished':
if human_input_yielded:
break
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
@@ -906,153 +636,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id']
async def _submit_workflow_form(
self, form_action: dict
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Submit human input to resume a paused Dify workflow."""
form_token = form_action['form_token']
workflow_run_id = form_action['workflow_run_id']
user = form_action['user']
action_id = form_action.get('action_id', '')
action_title = form_action.get('action_title', '') or action_id
node_title = form_action.get('node_title', '')
inputs = form_action.get('inputs', {})
messsage_idx = 0
is_final = False
think_start = False
think_end = False
workflow_contents = ''
repause_form_data: dict | None = None
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
async for chunk in self.dify_client.workflow_submit(
form_token=form_token,
workflow_run_id=workflow_run_id,
inputs=inputs,
user=user,
action=action_id,
timeout=120,
):
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
yield_this_iteration = False
if chunk['event'] == 'workflow_finished':
is_final = True
yield_this_iteration = True
if chunk['data'].get('error'):
raise errors.DifyAPIError(chunk['data']['error'])
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
new_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
for reason in reasons:
if reason.get('TYPE') != 'human_input_required':
continue
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
# Use a distinct name — `node_title` (the just-resolved step)
# must keep its value so the resume notice on the previous
# card still shows which step the user acted on.
paused_node_title = reason.get('node_title', '')
raw_inputs = reason.get('inputs', {})
_set_pending_form(
user,
{
'workflow_run_id': new_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': paused_node_title,
'form_content': form_content,
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': user,
},
)
repause_form_data = {
'form_content': form_content,
'actions': actions,
'node_title': paused_node_title,
'workflow_run_id': new_run_id,
'form_token': reason.get('form_token', ''),
}
# Ensure the final chunk has non-empty content so
# ResponseWrapper (which skips empty-content chunks) lets it
# propagate to SendResponseBackStage. Use a zero-width space
# so neither Lark nor Telegram renders visible noise — the
# adapter substitutes its own card text from _form_data.
if not workflow_contents:
workflow_contents = ''
is_final = True
yield_this_iteration = True
break
if chunk['event'] == 'text_chunk':
messsage_idx += 1
if remove_think:
if '<think>' in chunk['data']['text'] and not think_start:
think_start = True
continue
if '</think>' in chunk['data']['text'] and not think_end:
import re
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
workflow_contents += content
think_end = True
elif think_end:
workflow_contents += chunk['data']['text']
if think_start:
continue
else:
workflow_contents += chunk['data']['text']
if messsage_idx % 8 == 0:
yield_this_iteration = True
if yield_this_iteration:
msg = provider_message.MessageChunk(
role='assistant',
content=workflow_contents,
is_final=is_final,
)
msg._resume_from_form = True
if action_title:
msg._resume_action_title = action_title
if node_title:
msg._resume_node_title = node_title
if is_final and repause_form_data:
msg._form_data = repause_form_data
msg._open_new_card = True
yield msg
if is_final:
return
async def _workflow_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用工作流"""
# Check if this is a form action resume (button click or text match)
form_action_raw = query.variables.get('_dify_form_action')
session_key = _session_key_from_query(query)
if form_action_raw:
form_action = self._merge_pending_form_action(session_key, form_action_raw)
else:
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
if form_action:
_clear_pending_form(session_key, form_action.get('form_token') or None)
# Resume paused workflow via submit endpoint
async for msg in self._submit_workflow_form(form_action):
yield msg
return
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
@@ -1084,13 +672,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
think_start = False
think_end = False
workflow_contents = ''
workflow_run_id = ''
human_input_yielded = False
# Saved form data to attach to the final MessageChunk so the adapter
# can detect it when is_final=True and render buttons.
pending_form_data = None
display_text = ''
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
async for chunk in self.dify_client.workflow_run(
@@ -1101,62 +682,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
):
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
if chunk['event'] in ignored_events:
if chunk['event'] == 'workflow_started':
workflow_run_id = chunk['data'].get('workflow_run_id', '')
continue
if chunk['event'] == 'workflow_paused':
reasons = chunk['data'].get('reasons', [])
workflow_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
for reason in reasons:
if reason.get('TYPE') == 'human_input_required':
form_content = reason.get('form_content', '')
actions = reason.get('actions', [])
node_title = reason.get('node_title', '')
# Persist form state in module-level store keyed by session
raw_inputs = reason.get('inputs', {})
_set_pending_form(
_session_key_from_query(query),
{
'workflow_run_id': workflow_run_id,
'form_id': reason.get('form_id'),
'form_token': reason.get('form_token'),
'node_id': reason.get('node_id'),
'node_title': node_title,
'form_content': form_content,
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
'actions': actions,
'expiration_time': reason.get('expiration_time'),
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
},
)
# Pass form render metadata to downstream stages
query.variables['_dify_form_render'] = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
}
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
workflow_contents += display_text + '\n'
# Save form data to attach to the final chunk later.
# We do NOT yield here — the form content will be sent
# as the final MessageChunk (with is_final=True and
# _form_data) so the adapter can update the card and
# add buttons in one pass.
pending_form_data = {
'form_content': form_content,
'actions': actions,
'node_title': node_title,
'workflow_run_id': workflow_run_id,
'form_token': reason.get('form_token', ''),
}
human_input_yielded = True
if chunk['event'] == 'workflow_finished':
is_final = True
if chunk['data']['error']:
@@ -1204,29 +730,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
if messsage_idx % 8 == 0 or is_final:
final_content = workflow_contents if workflow_contents.strip() else ''
msg = provider_message.MessageChunk(
yield provider_message.MessageChunk(
role='assistant',
content=final_content,
content=workflow_contents,
is_final=is_final,
)
# Attach form data to the final chunk for the adapter
if is_final and pending_form_data:
msg._form_data = pending_form_data
pending_form_data = None
yield msg
# If the stream ended after workflow_paused without a
# workflow_finished event, yield a final chunk so the adapter
# can update the card and add buttons.
if human_input_yielded and not is_final:
msg = provider_message.MessageChunk(
role='assistant',
content=workflow_contents or display_text,
is_final=True,
)
msg._form_data = pending_form_data
yield msg
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""

View File

@@ -1,12 +1,8 @@
from __future__ import annotations
import posixpath
import re
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote
if TYPE_CHECKING:
from langbot.pkg.core import app
from typing import Any
from langbot.pkg.core import app
class RAGRuntimeService:
@@ -113,17 +109,8 @@ class RAGRuntimeService:
regardless of the underlying storage provider.
"""
# Validate storage_path to prevent path traversal
decoded_path = unquote(storage_path).replace('\\', '/')
decoded_segments = decoded_path.split('/')
normalized = posixpath.normpath(decoded_path)
if (
not storage_path
or '\x00' in decoded_path
or normalized.startswith('/')
or '..' in decoded_segments
or '..' in normalized.split('/')
or re.match(r'^[A-Za-z]:/', normalized)
):
normalized = posixpath.normpath(storage_path)
if normalized.startswith('/') or '..' in normalized.split('/'):
raise ValueError('Invalid storage path')
content_bytes = await self.ap.storage_mgr.storage_provider.load(normalized)
return content_bytes if content_bytes else b''

View File

@@ -13,11 +13,12 @@ class TelemetryManager:
await telemetry.send({ ... })
"""
send_tasks: list[asyncio.Task] = []
def __init__(self, ap: core_app.Application):
self.ap = ap
self.telemetry_config = {}
self.send_tasks: list[asyncio.Task] = []
async def initialize(self):
self.telemetry_config = self.ap.instance_config.data.get('space', {})

View File

@@ -83,7 +83,7 @@ def get_func_schema(function: typing.Callable) -> dict:
parameters['properties'][param.name] = {
'type': param_type,
'description': args_doc.get(param.name, ''),
'description': args_doc[param.name],
}
# add schema for array

View File

@@ -145,8 +145,7 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
"""获取QQ图片的下载链接"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
scheme = parsed.scheme or 'http'
return f'{scheme}://{parsed.netloc}{parsed.path}', query
return f'http://{parsed.netloc}{parsed.path}', query
async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, str]:

View File

@@ -23,10 +23,7 @@ def run_pip(params: list):
pipmain(params)
def install_requirements(file, extra_params: list | None = None):
if extra_params is None:
extra_params = []
def install_requirements(file, extra_params: list = []):
pipmain(
[
'install',

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import ipaddress
import re
from urllib.parse import urlparse
@@ -46,40 +44,6 @@ LOCAL_PATTERNS = [
'172.31.',
]
HOST_LABEL_PATTERN = re.compile(r'^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$')
def _is_valid_hostname(host: str) -> bool:
if host == 'localhost':
return True
try:
ipaddress.ip_address(host)
return True
except ValueError:
pass
if not host or len(host) > 253 or any(char.isspace() for char in host):
return False
host = host.rstrip('.')
if not host:
return False
return all(HOST_LABEL_PATTERN.match(label) for label in host.split('.'))
def _is_local_host(host: str) -> bool:
if host == 'localhost':
return True
try:
ip_address = ipaddress.ip_address(host)
except ValueError:
return False
return ip_address.is_private or ip_address.is_loopback or ip_address.is_unspecified
def get_runner_category(runner_name: str, runner_url: str) -> str:
if not runner_url:
@@ -88,15 +52,12 @@ def get_runner_category(runner_name: str, runner_url: str) -> str:
try:
parsed_url = urlparse(runner_url)
host = parsed_url.hostname.lower() if parsed_url.hostname else ''
_ = parsed_url.port
except Exception:
return RunnerCategory.UNKNOWN
if not parsed_url.scheme or not host or not _is_valid_hostname(host):
return RunnerCategory.UNKNOWN
if _is_local_host(host):
return RunnerCategory.LOCAL
for pattern in LOCAL_PATTERNS:
if host.startswith(pattern):
return RunnerCategory.LOCAL
for domain in CLOUD_DOMAINS:
if host.endswith(domain):

View File

@@ -2,48 +2,6 @@
This directory contains the test suite for LangBot, with a focus on comprehensive unit testing of pipeline stages.
## Quality Gate Layers
LangBot uses a layered quality gate system for developers and CI:
| Layer | Command | What it runs | When to use |
|-------|---------|--------------|-------------|
| **Quick** | `make test-quick` or `bash scripts/test-quick.sh` | Ruff lint + Unit tests + Smoke tests | Before every commit |
| **Fast Integration** | `make test-integration-fast` or `bash scripts/test-integration-fast.sh` | SQLite/API/Pipeline integration (no external services) | Before PR, weekly |
| **Coverage Gate** | `make test-coverage` or `bash scripts/test-coverage.sh` | All tests with coverage, threshold: 18% | Before merge, CI |
| **Full Local** | `make test-all-local` | Quick + Integration + Coverage | Before major changes |
**Note**: PostgreSQL migration tests and slow tests are NOT in local default gates. They run in separate CI workflows.
### Developer Workflow
```bash
# Daily: Quick self-test
bash scripts/test-quick.sh
# Before PR: Full local gate
make test-all-local
# Or run each layer separately:
bash scripts/test-quick.sh # ~2 min
bash scripts/test-integration-fast.sh # ~3 min
bash scripts/test-coverage.sh # ~8 min
```
### Coverage Baseline
Current coverage threshold: **18%**
Actual coverage: **30%**
This is a conservative baseline to prevent coverage regression. It does NOT represent the final quality target. Key modules have higher coverage:
- `pipeline.preproc.preproc`: 53%
- `pipeline.process.process`: 96%
- `pipeline.respback.respback`: 88%
- `telemetry.telemetry`: 87%
- `provider.session.sessionmgr`: 100%
- `provider.tools.toolmgr`: 83%
- `storage.providers.s3storage`: 80%
## Important Note
Due to circular import dependencies in the pipeline module structure, the test files use **lazy imports** via `importlib.import_module()` instead of direct imports. This ensures tests can run without triggering circular import errors.
@@ -52,81 +10,19 @@ Due to circular import dependencies in the pipeline module structure, the test f
```
tests/
├── __init__.py
├── factories/ # Shared test factories
│ ├── __init__.py # Factory exports
│ ├── app.py # FakeApp factory
│ ├── message.py # Message/query factories
│ ├── provider.py # FakeProvider factory
── platform.py # FakePlatform factory
├── integration/ # Integration tests (real resources)
│ ├── __init__.py
── api/ # HTTP API tests
├── __init__.py
│ │ └── test_smoke.py # API smoke tests
│ ├── pipeline/ # Pipeline stage-chain tests
│ │ ├── __init__.py
│ │ └── test_full_flow.py # Full flow integration
│ └── persistence/ # Database/persistence tests
│ ├── __init__.py
│ └── test_migrations.py # Alembic migration tests
├── smoke/ # Smoke tests (quick validation)
│ └── test_fake_message_flow.py
├── unit_tests/ # Unit tests
│ ├── box/ # Box module tests
│ ├── config/ # Configuration tests
│ ├── pipeline/ # Pipeline stage tests
│ │ └── conftest.py # Shared fixtures and test infrastructure
│ ├── platform/ # Platform adapter tests
│ ├── plugin/ # Plugin system tests
│ │ └── test_handler_actions.py # Action handler tests
│ ├── provider/ # Provider tests
│ │ ├── test_session_manager.py # SessionManager tests
│ │ └── test_tool_manager.py # ToolManager tests
│ ├── rag/ # RAG tests
│ │ └── test_file_storage.py # File/ZIP storage tests
│ ├── storage/ # Storage tests
│ │ └── test_s3storage.py # S3StorageProvider tests
│ ├── vector/ # Vector tests
│ │ └── test_vdb_filter_conversion.py # VDB filter tests
│ └── telemetry/ # Telemetry tests (rewritten)
├── utils/ # Test utilities
│ ├── __init__.py
│ └── import_isolation.py # sys.modules isolation for circular imports
└── README.md # This file
├── pipeline/ # Pipeline stage tests
│ ├── conftest.py # Shared fixtures and test infrastructure
│ ├── test_simple.py # Basic infrastructure tests (always pass)
│ ├── test_bansess.py # BanSessionCheckStage tests
│ ├── test_ratelimit.py # RateLimit stage tests
│ ├── test_preproc.py # PreProcessor stage tests
── test_respback.py # SendResponseBackStage tests
│ ├── test_resprule.py # GroupRespondRuleCheckStage tests
│ ├── test_pipelinemgr.py # PipelineManager tests
── test_stages_integration.py # Integration tests
└── README.md # This file
```
## Test Factories
The `tests/factories/` package provides reusable test factories:
```python
from tests.factories import (
FakeApp, # Mock application
FakeProvider, # Fake LLM provider
FakePlatform, # Fake platform adapter
text_query, # Create text query
group_text_query, # Create group query
command_query, # Create command query
)
# Create fake app
app = FakeApp()
# Create query with text
query = text_query("hello world")
# Create fake provider that returns specific response
provider = FakeProvider().returns("test response")
# Create fake platform for outbound capture
platform = FakePlatform()
await platform.reply_message(query.message_event, reply_chain)
outbound = platform.get_outbound_messages()
```
See `tests/factories/__init__.py` for all available factories.
## Test Architecture
### Fixtures (`conftest.py`)
@@ -147,28 +43,7 @@ The test suite uses a centralized fixture system that provides:
## Running Tests
### Quick self-test for developers
For local branch validation without real provider keys:
```bash
make test-quick
```
or
```bash
bash scripts/test-quick.sh
```
This runs:
1. Ruff lint check
2. Unit tests
3. Smoke tests
Suitable for quick validation before committing.
### Using the test runner script (recommended for full coverage)
### Using the test runner script (recommended)
```bash
bash run_tests.sh
```
@@ -181,135 +56,38 @@ This script automatically:
### Manual test execution
#### Run all unit tests
#### Run all tests
```bash
uv run pytest tests/unit_tests/ --cov=langbot --cov-report=xml --cov-report=term
pytest tests/pipeline/
```
#### Run specific test module
#### Run only simple tests (no imports, always pass)
```bash
uv run pytest tests/unit_tests/pipeline/ -v
pytest tests/pipeline/test_simple.py -v
```
#### Run specific test file
```bash
uv run pytest tests/unit_tests/pipeline/test_bansess.py -v
pytest tests/pipeline/test_bansess.py -v
```
#### Run with coverage
```bash
uv run pytest tests/unit_tests/pipeline/ --cov=langbot --cov-report=html
pytest tests/pipeline/ --cov=pkg/pipeline --cov-report=html
```
#### Run specific test
```bash
uv run pytest tests/unit_tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
pytest tests/pipeline/test_bansess.py::test_bansess_whitelist_allow -v
```
### Using markers
```bash
# Run only unit tests
uv run pytest tests/unit_tests/ -m unit
# Run only integration tests
uv run pytest tests/integration/ -m integration
# Run integration tests excluding slow ones
uv run pytest tests/integration/ -m "not slow" -q
# Skip slow tests
uv run pytest tests/unit_tests/ -m "not slow"
```
### Running integration tests
Integration tests validate real system behavior with actual database/network resources.
```bash
# Run all integration tests (excluding slow ones)
uv run pytest tests/integration/ -m "not slow" -q
# Run SQLite migration integration tests
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
# Run API smoke integration tests
uv run pytest tests/integration/api/test_smoke.py -q
# Run pipeline full-flow integration tests
uv run pytest tests/integration/pipeline/test_full_flow.py -q
# Run with verbose output
uv run pytest tests/integration/ -v
```
Note: Integration tests use:
- Temporary databases (tmp_path) for persistence tests
- Fake app/services for API tests (no real provider/platform)
- Fake runner/provider for pipeline tests (no real LLM API)
- Do not require external services
### Running migration tests locally
SQLite migration tests can be run locally without any external dependencies:
```bash
# SQLite migration tests (uses tmp_path, no external DB needed)
uv run pytest tests/integration/persistence/test_migrations.py -q --tb=short
```
PostgreSQL migration tests require an external PostgreSQL database:
```bash
# PostgreSQL migration tests (requires PostgreSQL service)
# Tests are marked as slow and skipped if TEST_POSTGRES_URL is not set
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
# Or skip by default (no PostgreSQL available)
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q --tb=short
# Output: SKIPPED (TEST_POSTGRES_URL not set)
```
Note: PostgreSQL tests are **not** included in fast integration gate because they:
- Require external PostgreSQL service
- Are marked with `@pytest.mark.slow`
- Need `TEST_POSTGRES_URL` environment variable
CI workflow `.github/workflows/test-migrations.yml` runs:
- SQLite tests in `test-migrations-sqlite` job (fast, no external services)
- PostgreSQL tests in `test-migrations-postgres` job (uses PostgreSQL service container)
### Running pipeline integration tests locally
Pipeline full-flow integration tests validate real stage interactions:
```bash
# Run pipeline integration tests (uses fake runner, no real LLM API)
uv run pytest tests/integration/pipeline/test_full_flow.py -q --tb=short
# Run with coverage for pipeline modules
uv run pytest tests/integration/pipeline \
--cov=langbot.pkg.pipeline.preproc.preproc \
--cov=langbot.pkg.pipeline.process.process \
--cov=langbot.pkg.pipeline.respback.respback \
--cov-report=term -q
```
These tests:
- Use `FakeRunner` class to simulate LLM responses without real API calls
- Import real `PreProcessor`, `MessageProcessor`, `SendResponseBackStage` stages
- Validate stage chain: PreProcessor → Processor → SendResponseBackStage
- Test prevent_default, exception handling, and full message flow
- Do not require real LLM provider keys
### Known Issues
Some tests may encounter circular import errors. This is a known issue with the current module structure. The test infrastructure is designed to work around this using lazy imports, but if you encounter issues:
1. Make sure you're running from the project root directory
2. Ensure dependencies are installed: `uv sync --dev`
3. Try running a simple test first to verify the test infrastructure works
2. Ensure the virtual environment is activated
3. Try running `test_simple.py` first to verify the test infrastructure works
## CI/CD Integration
@@ -319,7 +97,7 @@ Tests are automatically run on:
- Push to PR branch
- Push to master/develop branches
The workflow runs tests on Python 3.11, 3.12, and 3.13 to ensure compatibility.
The workflow runs tests on Python 3.10, 3.11, and 3.12 to ensure compatibility.
## Adding New Tests
@@ -333,8 +111,8 @@ Create a new test file `test_<stage_name>.py`:
"""
import pytest
from langbot.pkg.pipeline.<module>.<stage> import <StageClass>
from langbot.pkg.pipeline import entities as pipeline_entities
from pkg.pipeline.<module>.<stage> import <StageClass>
from pkg.pipeline import entities as pipeline_entities
@pytest.mark.asyncio
@@ -350,7 +128,7 @@ async def test_stage_basic_flow(mock_app, sample_query):
### 2. For additional fixtures
Add new fixtures to the appropriate `conftest.py`:
Add new fixtures to `conftest.py`:
```python
@pytest.fixture
@@ -364,7 +142,7 @@ def my_custom_fixture():
Use the helper functions in `conftest.py`:
```python
from tests.unit_tests.pipeline.conftest import create_stage_result, assert_result_continue
from tests.pipeline.conftest import create_stage_result, assert_result_continue
result = create_stage_result(
result_type=pipeline_entities.ResultType.CONTINUE,
@@ -388,7 +166,7 @@ assert_result_continue(result)
### Import errors
Make sure you've installed the package in development mode:
```bash
uv sync --dev
uv pip install -e .
```
### Async test failures
@@ -399,11 +177,7 @@ Check that you're mocking at the right level and using `AsyncMock` for async fun
## Future Enhancements
- [x] Add integration tests for database migrations (SQLite)
- [x] Add PostgreSQL migration integration tests (G-003)
- [x] Add integration tests for full pipeline execution
- [x] Add API smoke integration tests
- [ ] Add E2E tests
- [ ] Add integration tests for full pipeline execution
- [ ] Add performance benchmarks
- [ ] Add mutation testing for better coverage quality
- [ ] Add property-based testing with Hypothesis
- [ ] Add property-based testing with Hypothesis

View File

@@ -1,102 +0,0 @@
"""E2E test fixtures.
Provides fixtures for starting real LangBot process with minimal configuration.
"""
from __future__ import annotations
import pytest
import tempfile
import shutil
import logging
from pathlib import Path
from tests.e2e.utils.config_factory import create_minimal_config, create_test_directories
from tests.e2e.utils.process_manager import LangBotProcess, find_project_root
logger = logging.getLogger(__name__)
pytestmark = pytest.mark.e2e
@pytest.fixture(scope='session')
def e2e_port():
"""Port for E2E testing (non-default to avoid conflicts)."""
return 15300
@pytest.fixture(scope='session')
def e2e_tmpdir():
"""Create temporary directory for E2E testing."""
tmpdir = Path(tempfile.mkdtemp(prefix='langbot_e2e_'))
logger.info(f'E2E tmpdir: {tmpdir}')
yield tmpdir
# Cleanup
logger.info(f'Cleaning up E2E tmpdir: {tmpdir}')
shutil.rmtree(tmpdir, ignore_errors=True)
@pytest.fixture(scope='session')
def e2e_config_path(e2e_tmpdir, e2e_port):
"""Create minimal config.yaml for E2E testing."""
config_path = create_minimal_config(e2e_tmpdir, port=e2e_port)
create_test_directories(e2e_tmpdir)
logger.info(f'E2E config: {config_path}')
return config_path
@pytest.fixture(scope='session')
def langbot_process(e2e_config_path, e2e_port, e2e_tmpdir):
"""Start real LangBot process for E2E testing.
This fixture starts LangBot once per session and reuses it for all tests.
Coverage data is collected from the subprocess.
"""
project_root = find_project_root()
collect_coverage = True
proc = LangBotProcess(
project_root=project_root,
work_dir=e2e_tmpdir, # Run in tmpdir where data/config.yaml exists
port=e2e_port,
timeout=60, # Longer timeout for first startup
collect_coverage=collect_coverage,
)
success = proc.start()
if not success:
stdout, stderr = proc.get_logs()
pytest.fail(f'LangBot failed to start:\nstdout: {stdout}\nstderr: {stderr}')
yield proc
# Cleanup
proc.stop()
# Combine coverage data if collected
if collect_coverage and proc.get_coverage_file():
coverage_file = proc.get_coverage_file()
if coverage_file.exists():
# Copy coverage data to project root for combining
target = project_root / '.coverage.e2e'
shutil.copy(coverage_file, target)
logger.info(f'Coverage data saved to: {target}')
@pytest.fixture
def e2e_client(e2e_port, langbot_process):
"""HTTP client for E2E testing."""
import httpx
base_url = f'http://127.0.0.1:{e2e_port}'
with httpx.Client(base_url=base_url, timeout=10.0) as client:
yield client
@pytest.fixture(scope='session')
def e2e_db_path(e2e_tmpdir):
"""Path to SQLite database file."""
return e2e_tmpdir / 'data' / 'langbot.db'

View File

@@ -1,142 +0,0 @@
"""E2E tests for LangBot startup flow.
Tests the complete startup process including:
- boot.py startup orchestration
- stages/ (build_app, load_config, migrate, etc.)
- database initialization
- API availability
Run: uv run pytest tests/e2e/test_startup.py -v -m e2e
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.e2e
class TestStartupFlow:
"""Tests for LangBot startup process."""
def test_process_is_running(self, langbot_process):
"""Verify LangBot process is running."""
assert langbot_process.is_running()
def test_health_check(self, langbot_process, e2e_port):
"""Verify LangBot API is responding."""
assert langbot_process.health_check()
def test_system_info_endpoint(self, e2e_client):
"""Test /api/v1/system/info endpoint."""
response = e2e_client.get('/api/v1/system/info')
assert response.status_code == 200
data = response.json()
assert data['code'] == 0
assert 'data' in data
# System info should contain version info
assert 'version' in data['data'] or 'edition' in data['data']
def test_database_initialized(self, e2e_db_path):
"""Verify SQLite database was created and initialized."""
assert e2e_db_path.exists()
# Database should have some tables after migration
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
# Check that core tables exist
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [row[0] for row in cursor.fetchall()]
# Core tables should be created by Alembic migrations
# Note: table names may differ (legacy_pipelines instead of pipelines)
expected_tables = ['legacy_pipelines', 'bots', 'model_providers', 'llm_models']
for table in expected_tables:
assert table in tables, f'Table {table} should exist. Available: {tables}'
conn.close()
def test_chroma_directory_created(self, e2e_tmpdir):
"""Verify Chroma vector database directory was created."""
chroma_path = e2e_tmpdir / 'chroma'
# Created by the E2E config factory before startup.
assert chroma_path.exists()
def test_pipelines_endpoint(self, e2e_client):
"""Test /api/v1/pipelines endpoint (requires auth)."""
# Without auth, should return 401
response = e2e_client.get('/api/v1/pipelines')
assert response.status_code == 401
def test_auth_endpoint(self, e2e_client, e2e_tmpdir):
"""Test auth endpoint."""
# First startup may allow initial setup
response = e2e_client.post('/api/v1/user/auth', json={
'username': 'admin',
'password': 'admin',
})
# Response could be:
# - 200 if auth succeeds
# - 400 if credentials wrong
# - 401 if user not initialized
assert response.status_code in [200, 400, 401]
class TestStartupStages:
"""Tests that verify individual startup stages worked correctly."""
def test_config_loaded(self, e2e_client):
"""Verify config was loaded correctly by checking API port."""
# If API responds on e2e_port, config was loaded
assert e2e_client.get('/api/v1/system/info').status_code == 200
def test_migrations_applied(self, e2e_db_path):
"""Verify database migrations were applied."""
import sqlite3
conn = sqlite3.connect(str(e2e_db_path))
cursor = conn.cursor()
# Check alembic_version table exists and has version
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='alembic_version';")
result = cursor.fetchone()
assert result is not None, 'alembic_version table should exist'
cursor.execute('SELECT version_num FROM alembic_version;')
version = cursor.fetchone()
assert version is not None, 'Migration version should be set'
conn.close()
def test_http_controller_initialized(self, e2e_client):
"""Verify HTTP controller was initialized."""
# Multiple endpoints should be available
endpoints = [
'/api/v1/system/info',
'/api/v1/pipelines',
'/api/v1/provider/providers',
'/api/v1/platform/bots',
]
for endpoint in endpoints:
response = e2e_client.get(endpoint)
# Should get a real route response, even if auth is required.
assert response.status_code in [200, 401, 403], f'{endpoint} should be registered'
class TestMinimalStartupNoLLM:
"""Tests verifying LangBot can start without LLM providers."""
def test_api_available_without_llm(self, e2e_client):
"""API should be available even without LLM providers configured."""
response = e2e_client.get('/api/v1/system/info')
assert response.status_code == 200
def test_pipeline_metadata_available(self, e2e_client):
"""Pipeline metadata endpoint should work without LLM."""
# Requires auth, but endpoint should exist
response = e2e_client.get('/api/v1/pipelines/_/metadata')
assert response.status_code in [200, 401] # Not 404 or 500

View File

@@ -1,179 +0,0 @@
"""E2E test configuration factory.
Generates minimal config.yaml for testing LangBot startup without external dependencies.
"""
from __future__ import annotations
import yaml
from pathlib import Path
def create_minimal_config(tmpdir: Path, port: int = 15300) -> Path:
"""Create minimal config.yaml for E2E testing.
Uses embedded databases (SQLite, Chroma) to avoid external dependencies.
Config is created at tmpdir/data/config.yaml (LangBot expects this location).
"""
# LangBot expects config at data/config.yaml
data_dir = tmpdir / 'data'
data_dir.mkdir(parents=True, exist_ok=True)
config = {
'admins': [],
'api': {
'port': port,
'webhook_prefix': f'http://127.0.0.1:{port}',
'extra_webhook_prefix': '',
},
'command': {
'enable': True,
'prefix': ['!', '!'],
'privilege': {},
},
'concurrency': {
'pipeline': 20,
'session': 1,
},
'proxy': {
'http': '',
'https': '',
},
'system': {
'instance_id': '',
'edition': 'community',
'recovery_key': '',
'allow_modify_login_info': True,
'disabled_adapters': [],
'limitation': {
'max_bots': -1,
'max_pipelines': -1,
'max_extensions': -1,
},
'task_retention': {
'completed_limit': 200,
},
'jwt': {
'expire': 604800,
'secret': 'e2e-test-secret-key',
},
},
'database': {
'use': 'sqlite',
'sqlite': {
'path': str(tmpdir / 'data' / 'langbot.db'),
},
'postgresql': {
'host': '127.0.0.1',
'port': 5432,
'user': 'postgres',
'password': 'postgres',
'database': 'postgres',
},
},
'vdb': {
'use': 'chroma', # Chroma is embedded, no external dependency
'chroma': {
'path': str(tmpdir / 'chroma'),
},
'qdrant': {
'url': '',
'host': 'localhost',
'port': 6333,
'api_key': '',
},
'seekdb': {
'mode': 'embedded',
'path': str(tmpdir / 'seekdb'),
'database': 'langbot',
'host': 'localhost',
'port': 2881,
'user': 'root',
'password': '',
'tenant': '',
},
'milvus': {
'uri': 'http://127.0.0.1:19530',
'token': '',
'db_name': '',
},
'pgvector': {
'host': '127.0.0.1',
'port': 5433,
'database': 'langbot',
'user': 'postgres',
'password': 'postgres',
},
},
'storage': {
'use': 'local',
'cleanup': {
'enabled': False, # Disable cleanup for tests
'check_interval_hours': 1,
'uploaded_file_retention_days': 7,
'log_retention_days': 3,
},
'local': {
'path': str(tmpdir / 'storage'),
},
's3': {
'endpoint_url': '',
'access_key_id': '',
'secret_access_key': '',
'region': 'us-east-1',
'bucket': 'langbot-storage',
},
},
'plugin': {
'enable': False, # Disable plugin system for minimal startup
'runtime_ws_url': '',
'enable_marketplace': False,
'display_plugin_debug_url': '',
'binary_storage': {
'max_value_bytes': 10485760,
},
},
'monitoring': {
'auto_cleanup': {
'enabled': False, # Disable cleanup for tests
'retention_days': 30,
'check_interval_hours': 1,
'delete_batch_size': 1000,
},
},
'space': {
'url': 'https://space.langbot.app',
'models_gateway_api_url': 'https://api.langbot.cloud/v1',
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
'disable_models_service': True, # Disable external services
'disable_telemetry': True, # Disable telemetry for tests
},
'provider': {}, # Empty providers - minimal startup
'llm': [], # Empty LLM models
}
# Ensure data directory exists (LangBot expects config at data/config.yaml)
data_dir = tmpdir / 'data'
data_dir.mkdir(parents=True, exist_ok=True)
# Write config to data/config.yaml (LangBot's expected location)
config_path = data_dir / 'config.yaml'
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, default_flow_style=False)
return config_path
def create_test_directories(tmpdir: Path) -> dict[str, Path]:
"""Create necessary directories for LangBot testing."""
directories = {
'data': tmpdir / 'data',
'logs': tmpdir / 'logs',
'storage': tmpdir / 'storage',
'chroma': tmpdir / 'chroma',
}
for path in directories.values():
path.mkdir(parents=True, exist_ok=True)
return directories

View File

@@ -1,204 +0,0 @@
"""E2E test process manager.
Manages LangBot subprocess lifecycle for E2E testing.
"""
from __future__ import annotations
import subprocess
import time
import signal
import os
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)
class LangBotProcess:
"""Manages a LangBot subprocess for E2E testing."""
def __init__(
self,
project_root: Path,
work_dir: Path,
port: int = 15300,
timeout: int = 30,
collect_coverage: bool = True,
):
self.project_root = project_root
self.work_dir = work_dir # Directory containing data/config.yaml
self.port = port
self.timeout = timeout
self.collect_coverage = collect_coverage
self.process: Optional[subprocess.Popen] = None
self._stdout_data: bytes = b''
self._stderr_data: bytes = b''
self._coverage_file: Optional[Path] = None
def start(self) -> bool:
"""Start LangBot process and wait for it to be ready."""
import httpx
# Prepare environment
env = os.environ.copy()
env['PYTHONPATH'] = str(self.project_root / 'src')
# Set API port via environment variable
env['API__PORT'] = str(self.port)
env['API__WEBHOOK_PREFIX'] = f'http://127.0.0.1:{self.port}'
# Disable telemetry
env['SPACE__DISABLE_TELEMETRY'] = 'true'
env['SPACE__DISABLE_MODELS_SERVICE'] = 'true'
# Build command
if self.collect_coverage:
# Use coverage.py to collect coverage data
# Set COVERAGE_PROCESS_START to enable coverage in subprocess
self._coverage_file = self.work_dir / '.coverage.e2e'
env['COVERAGE_PROCESS_START'] = str(self.project_root / '.coveragerc')
env['COVERAGE_FILE'] = str(self._coverage_file)
# Create .coveragerc for subprocess
coveragerc_content = """
[run]
source = langbot.pkg
parallel = True
data_file = {}
omit =
*/tests/*
*/test_*.py
[report]
precision = 2
""".format(str(self._coverage_file))
coveragerc_path = self.work_dir / '.coveragerc'
with open(coveragerc_path, 'w') as f:
f.write(coveragerc_content)
cmd = [
'coverage', 'run',
'--rcfile=' + str(coveragerc_path),
'-m', 'langbot',
]
else:
cmd = ['uv', 'run', 'python', '-m', 'langbot']
logger.info(f'Starting LangBot in: {self.work_dir}')
logger.info(f'Command: {cmd}')
# Start process (run in work_dir so it finds data/config.yaml)
self.process = subprocess.Popen(
cmd,
cwd=self.work_dir,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if os.name != 'nt' else None,
)
# Wait for startup
start_time = time.time()
while time.time() - start_time < self.timeout:
# Check if process died
if self.process.poll() is not None:
self._stdout_data, self._stderr_data = self.process.communicate()
logger.error(f'LangBot process died: {self._stderr_data.decode()}')
return False
# Try to connect
try:
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=2.0,
)
if r.status_code == 200:
logger.info(f'LangBot started successfully on port {self.port}')
return True
except (httpx.ConnectError, httpx.TimeoutException):
pass
time.sleep(1)
# Timeout
logger.error(f'LangBot startup timeout after {self.timeout}s')
self.stop()
return False
def stop(self) -> None:
"""Stop LangBot process gracefully."""
if self.process is None:
return
logger.info('Stopping LangBot process...')
# Try graceful shutdown first
if os.name != 'nt':
# Send SIGTERM to process group
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
else:
self.process.terminate()
# Wait for graceful shutdown
try:
self.process.wait(timeout=5)
logger.info('LangBot stopped gracefully')
except subprocess.TimeoutExpired:
# Force kill
logger.warning('Force killing LangBot process')
if os.name != 'nt':
os.killpg(os.getpgid(self.process.pid), signal.SIGKILL)
else:
self.process.kill()
self.process.wait()
# Collect output for debugging
if self.process.stdout or self.process.stderr:
self._stdout_data, self._stderr_data = self.process.communicate()
self.process = None
def is_running(self) -> bool:
"""Check if process is still running."""
return self.process is not None and self.process.poll() is None
def get_logs(self) -> tuple[str, str]:
"""Get stdout and stderr logs."""
stdout = self._stdout_data.decode('utf-8', errors='replace')
stderr = self._stderr_data.decode('utf-8', errors='replace')
return stdout, stderr
def get_coverage_file(self) -> Optional[Path]:
"""Get coverage data file path."""
return self._coverage_file
def health_check(self) -> bool:
"""Check if LangBot API is responding."""
import httpx
if not self.is_running():
return False
try:
r = httpx.get(
f'http://127.0.0.1:{self.port}/api/v1/system/info',
timeout=5.0,
)
return r.status_code == 200
except Exception:
return False
def find_project_root() -> Path:
"""Find LangBot project root directory."""
current = Path(__file__).resolve()
# Walk up until we find src/langbot
for parent in current.parents:
if (parent / 'src' / 'langbot').exists():
return parent
# Fallback to LangBot-test-build directory
return Path('/home/glwuy/langbot-app/LangBot-test-build')

View File

@@ -1,102 +0,0 @@
"""
Shared test factories for LangBot tests.
Provides reusable factories for:
- Fake application (app.py)
- Messages and queries (message.py)
- Fake providers (provider.py)
- Fake platforms (platform.py)
Usage:
from tests.factories import FakeApp, text_query, FakeProvider
app = FakeApp()
query = text_query("hello")
provider = FakeProvider.returns("response")
"""
from tests.factories.app import FakeApp, fake_app
from tests.factories.message import (
text_chain,
group_text_chain,
mention_chain,
image_chain,
text_query,
group_text_query,
private_text_query,
command_query,
mention_query,
empty_query,
image_query,
file_query,
unsupported_query,
voice_query,
at_all_query,
query_with_session,
query_with_config,
friend_message_event,
group_message_event,
mock_adapter,
)
from tests.factories.provider import (
FakeProvider,
fake_provider,
fake_provider_pong,
fake_provider_timeout,
fake_provider_auth_error,
fake_provider_rate_limit,
fake_provider_malformed,
fake_model,
)
from tests.factories.platform import (
FakePlatform,
fake_platform,
fake_platform_with_streaming,
fake_platform_with_failure,
mock_platform_adapter,
)
__all__ = [
# App
"FakeApp",
"fake_app",
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
# Message events
"friend_message_event",
"group_message_event",
# Mock adapters
"mock_adapter",
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
"image_query",
"file_query",
"unsupported_query",
"voice_query",
"at_all_query",
"query_with_session",
"query_with_config",
# Provider
"FakeProvider",
"fake_provider",
"fake_provider_pong",
"fake_provider_timeout",
"fake_provider_auth_error",
"fake_provider_rate_limit",
"fake_provider_malformed",
"fake_model",
# Platform
"FakePlatform",
"fake_platform",
"fake_platform_with_streaming",
"fake_platform_with_failure",
"mock_platform_adapter",
]

View File

@@ -1,137 +0,0 @@
"""
Fake application factory for tests.
Provides a mock Application object with all dependencies needed by pipeline stages.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
class FakeApp:
"""Mock Application object providing all basic dependencies needed by stages."""
def __init__(
self,
*,
command_prefix: list[str] = ["/", "!"],
command_enable: bool = True,
pipeline_concurrency: int = 10,
admins: list[str] | None = None,
**extra_attrs,
):
self.logger = self._create_mock_logger()
self.sess_mgr = self._create_mock_session_manager()
self.model_mgr = self._create_mock_model_manager()
self.tool_mgr = self._create_mock_tool_manager()
self.plugin_connector = self._create_mock_plugin_connector()
self.persistence_mgr = self._create_mock_persistence_manager()
self.query_pool = self._create_mock_query_pool()
self.instance_config = self._create_mock_instance_config(
command_prefix=command_prefix,
command_enable=command_enable,
pipeline_concurrency=pipeline_concurrency,
admins=admins or [],
)
self.task_mgr = self._create_mock_task_manager()
# Handler-specific optional attributes
self.telemetry = self._create_mock_telemetry()
self.survey = None
self.cmd_mgr = self._create_mock_cmd_mgr()
# Apply any extra attributes for specific test scenarios
for name, value in extra_attrs.items():
setattr(self, name, value)
# Captured outbound messages (for assertions)
self._outbound_messages: list = []
def _create_mock_logger(self):
logger = Mock()
logger.debug = Mock()
logger.info = Mock()
logger.error = Mock()
logger.warning = Mock()
return logger
def _create_mock_session_manager(self):
sess_mgr = AsyncMock()
sess_mgr.get_session = AsyncMock()
sess_mgr.get_conversation = AsyncMock()
return sess_mgr
def _create_mock_model_manager(self):
model_mgr = AsyncMock()
model_mgr.get_model_by_uuid = AsyncMock()
return model_mgr
def _create_mock_tool_manager(self):
tool_mgr = AsyncMock()
tool_mgr.get_all_tools = AsyncMock(return_value=[])
return tool_mgr
def _create_mock_plugin_connector(self):
plugin_connector = AsyncMock()
plugin_connector.emit_event = AsyncMock()
return plugin_connector
def _create_mock_persistence_manager(self):
persistence_mgr = AsyncMock()
persistence_mgr.execute_async = AsyncMock()
return persistence_mgr
def _create_mock_query_pool(self):
query_pool = Mock()
query_pool.cached_queries = {}
query_pool.queries = []
query_pool.condition = AsyncMock()
return query_pool
def _create_mock_instance_config(
self,
command_prefix: list[str],
command_enable: bool,
pipeline_concurrency: int,
admins: list[str],
):
instance_config = Mock()
instance_config.data = {
"command": {"prefix": command_prefix, "enable": command_enable},
"concurrency": {"pipeline": pipeline_concurrency},
"admins": admins,
}
return instance_config
def _create_mock_task_manager(self):
task_mgr = Mock()
task_mgr.create_task = Mock()
return task_mgr
def _create_mock_telemetry(self):
telemetry = AsyncMock()
telemetry.start_send_task = AsyncMock()
return telemetry
def _create_mock_cmd_mgr(self):
cmd_mgr = AsyncMock()
cmd_mgr.execute = AsyncMock()
return cmd_mgr
def capture_message(self, message):
"""Capture an outbound message for test assertions."""
self._outbound_messages.append(message)
def get_outbound_messages(self) -> list:
"""Get all captured outbound messages."""
return self._outbound_messages.copy()
def clear_outbound_messages(self):
"""Clear captured outbound messages."""
self._outbound_messages.clear()
def fake_app(**kwargs) -> FakeApp:
"""Create a FakeApp instance with optional overrides."""
return FakeApp(**kwargs)

View File

@@ -1,472 +0,0 @@
"""
Message and query factories for tests.
Provides reusable factories for creating message chains, events, and query objects.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import typing
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.provider.session as provider_session
# Counter for generating unique IDs
_query_counter = 0
def _next_query_id() -> int:
"""Generate a unique query ID."""
global _query_counter
_query_counter += 1
return _query_counter
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.Image(url=url))
return platform_message.MessageChain(components)
def command_chain(
command: str = "help",
prefix: str = "/",
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
# ============== Message Event Factories ==============
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
id=sender_id,
nickname=nickname,
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
id=group_id,
name=group_name,
permission=platform_entities.Permission.Member,
)
sender = platform_entities.GroupMember(
id=sender_id,
member_name=sender_name,
permission=platform_entities.Permission.Member,
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
# ============== Mock Adapter Factory ==============
def mock_adapter() -> Mock:
"""Create a mock platform adapter."""
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=False)
adapter.reply_message = AsyncMock()
adapter.reply_message_chunk = AsyncMock()
return adapter
# ============== Query Factories ==============
def _base_query(
message_chain: platform_message.MessageChain,
message_event: platform_events.MessageEvent,
launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
adapter: Mock,
**overrides,
) -> pipeline_query.Query:
"""Create a base query with model_construct to bypass validation."""
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
}
# Apply overrides
for key, value in overrides.items():
base_data[key] = value
return pipeline_query.Query.model_construct(**base_data)
def text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a basic text query (private chat)."""
chain = text_chain(text)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def private_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a private text query (alias for text_query)."""
return text_query(text, sender_id, **overrides)
def group_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a group text query."""
chain = text_chain(text)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def command_query(
command: str = "help",
prefix: str = "/",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a command-like query."""
chain = command_chain(command, prefix)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def mention_query(
text: str = "hello",
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a mention-bot query (group chat with @mention)."""
chain = mention_chain(text, target)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def empty_query(**overrides) -> pipeline_query.Query:
"""Create an empty message query."""
chain = platform_message.MessageChain([])
event = friend_message_event(chain)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
adapter=adapter,
**overrides,
)
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create an image query."""
chain = image_chain(text, url)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def file_query(
url: str = "https://example.com/document.pdf",
name: str = "document.pdf",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a file attachment query."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.File(url=url, name=name))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def unsupported_query(
unsupported_type: str = "CustomComponent",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a query with unsupported/unknown message segment."""
components = []
if text:
components.append(platform_message.Plain(text=text))
# Use Unknown component for unsupported types
components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def query_with_session(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
session: provider_session.Session = None,
**overrides,
) -> pipeline_query.Query:
"""Create a query with a session object.
If session is None, creates a default session with empty conversation.
"""
if session is None:
# Create a default session
session = provider_session.Session(
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
use_prompt_name="default",
using_conversation=None,
conversations=[],
)
return text_query(text, sender_id, session=session, **overrides)
def query_with_config(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
pipeline_config: dict = None,
**overrides,
) -> pipeline_query.Query:
"""Create a query with custom pipeline configuration.
If pipeline_config is None, uses default config.
Useful for testing specific stage behaviors.
"""
if pipeline_config is None:
pipeline_config = {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
}
return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides)
def voice_query(
url: str = "https://example.com/audio.mp3",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a voice/audio query."""
components = [
platform_message.Voice(url=url),
]
chain = platform_message.MessageChain(components)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def at_all_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a group query with @All mention."""
components = [
platform_message.AtAll(),
platform_message.Plain(text=f" {text}"),
]
chain = platform_message.MessageChain(components)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)

View File

@@ -1,336 +0,0 @@
"""
Fake platform factory for tests.
Provides a fake platform adapter for tests that need inbound message injection
and outbound message capture.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import typing
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
class FakePlatform:
"""Fake platform adapter for unit and integration tests.
Simulates platform behavior without real network calls:
- Inbound text message construction
- Group and private conversation identities
- Mention-bot flag
- Outbound text capture
- Outbound file/image capture
- Send failure simulation
Does not start real platform adapters.
Does not call IM platform SDKs.
"""
def __init__(
self,
*,
bot_account_id: str = "test-bot",
stream_output_supported: bool = False,
raise_error: Exception = None,
):
self.bot_account_id = bot_account_id
self._stream_output_supported = stream_output_supported
self._raise_error = raise_error
# Captured outbound messages
self._outbound_messages: list[dict] = []
self._outbound_chunks: list[dict] = []
# Registered listeners
self._listeners: dict = {}
def raises(self, error: Exception) -> "FakePlatform":
"""Configure platform to raise an error on send."""
self._raise_error = error
return self
def send_failure(self) -> "FakePlatform":
"""Configure platform to simulate send failure."""
return self.raises(Exception("Platform send failure"))
def supports_streaming(self, supported: bool = True) -> "FakePlatform":
"""Configure whether streaming output is supported."""
self._stream_output_supported = supported
return self
def get_outbound_messages(self) -> list[dict]:
"""Get all captured outbound messages for assertions."""
return self._outbound_messages.copy()
def get_outbound_chunks(self) -> list[dict]:
"""Get all captured outbound streaming chunks for assertions."""
return self._outbound_chunks.copy()
def clear_outbound(self):
"""Clear captured outbound messages."""
self._outbound_messages.clear()
self._outbound_chunks.clear()
def last_message(self) -> dict | None:
"""Get the last captured outbound message."""
return self._outbound_messages[-1] if self._outbound_messages else None
def last_chunk(self) -> dict | None:
"""Get the last captured streaming chunk."""
return self._outbound_chunks[-1] if self._outbound_chunks else None
# ============== Inbound Message Construction ==============
def create_friend_message(
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
) -> platform_events.FriendMessage:
"""Create an inbound friend (private) message event."""
sender = platform_entities.Friend(
id=sender_id,
nickname=nickname,
remark=None,
)
chain = platform_message.MessageChain([
platform_message.Plain(text=text),
])
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
def create_group_message(
self,
text: str,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
mention_bot: bool = False,
) -> platform_events.GroupMessage:
"""Create an inbound group message event.
Args:
text: Message text content
sender_id: Sender user ID
sender_name: Sender display name
group_id: Group ID
group_name: Group name
mention_bot: If True, prepend @mention of bot account
"""
group = platform_entities.Group(
id=group_id,
name=group_name,
permission=platform_entities.Permission.Member,
)
sender = platform_entities.GroupMember(
id=sender_id,
member_name=sender_name,
permission=platform_entities.Permission.Member,
group=group,
)
# Build message chain with optional mention
components = []
if mention_bot:
components.append(platform_message.At(target=self.bot_account_id))
components.append(platform_message.Plain(text=" "))
components.append(platform_message.Plain(text=text))
chain = platform_message.MessageChain(components)
return platform_events.GroupMessage(
type="GroupMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
def create_image_message(
self,
url: str = "https://example.com/image.png",
text: str = "",
sender_id: typing.Union[int, str] = 12345,
is_group: bool = False,
group_id: typing.Union[int, str] = 99999,
) -> platform_events.MessageEvent:
"""Create an inbound image message event."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.Image(url=url))
chain = platform_message.MessageChain(components)
if is_group:
return self.create_group_message("", sender_id, group_id=group_id)
# Replace chain
else:
sender = platform_entities.Friend(id=sender_id, nickname="TestUser", remark=None)
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=chain,
time=1609459200,
)
# ============== Adapter Methods (Simulated) ==============
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain,
):
"""Simulate sending a message (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "send",
"target_type": target_type,
"target_id": target_id,
"message": message,
})
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""Simulate replying to a message (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_messages.append({
"type": "reply",
"source_type": message_source.type,
"source": message_source,
"message": message,
"quote_origin": quote_origin,
})
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message: dict,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""Simulate streaming reply (captures for assertions)."""
if self._raise_error:
raise self._raise_error
self._outbound_chunks.append({
"type": "reply_chunk",
"source_type": message_source.type,
"source": message_source,
"bot_message": bot_message,
"message": message,
"quote_origin": quote_origin,
"is_final": is_final,
})
async def is_stream_output_supported(self) -> bool:
"""Return whether streaming output is supported."""
return self._stream_output_supported
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable,
):
"""Register an event listener (stores for simulation)."""
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(callback)
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable,
):
"""Unregister an event listener."""
if event_type in self._listeners:
self._listeners[event_type].remove(callback)
async def run_async(self):
"""Simulate running the adapter (does nothing)."""
pass
async def kill(self) -> bool:
"""Simulate killing the adapter."""
return True
async def is_muted(self, group_id: int) -> bool:
"""Simulate checking mute status."""
return False
async def create_message_card(
self,
message_id: typing.Type[str, int],
event: platform_events.MessageEvent,
) -> bool:
"""Simulate creating a message card."""
return False
# ============== Simulation Helpers ==============
async def simulate_inbound_event(
self,
event: platform_events.Event,
):
"""Simulate receiving an inbound event by calling registered listeners."""
listeners = self._listeners.get(type(event), [])
for callback in listeners:
await callback(event, self)
def fake_platform(
bot_account_id: str = "test-bot",
stream_output_supported: bool = False,
) -> FakePlatform:
"""Create a FakePlatform instance."""
return FakePlatform(
bot_account_id=bot_account_id,
stream_output_supported=stream_output_supported,
)
def fake_platform_with_streaming() -> FakePlatform:
"""Create a FakePlatform that supports streaming output."""
return FakePlatform(stream_output_supported=True)
def fake_platform_with_failure() -> FakePlatform:
"""Create a FakePlatform that simulates send failure."""
return FakePlatform().send_failure()
# ============== Mock Adapter (for Query) ==============
def mock_platform_adapter(platform: FakePlatform = None) -> Mock:
"""Create a mock platform adapter using FakePlatform or a simple mock."""
if platform is None:
platform = FakePlatform()
adapter = Mock()
adapter.bot_account_id = platform.bot_account_id
adapter.reply_message = AsyncMock(side_effect=platform.reply_message)
adapter.reply_message_chunk = AsyncMock(side_effect=platform.reply_message_chunk)
adapter.send_message = AsyncMock(side_effect=platform.send_message)
adapter.is_stream_output_supported = AsyncMock(
return_value=platform._stream_output_supported
)
adapter._fake_platform = platform # Store for assertions
return adapter

View File

@@ -1,224 +0,0 @@
"""
Fake provider factory for tests.
Provides a deterministic fake provider that simulates LLM responses without real API calls.
"""
from __future__ import annotations
from unittest.mock import Mock
import typing
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class FakeProvider:
"""Deterministic fake provider for unit and integration tests.
Simulates various provider behaviors:
- Normal text response
- Streaming response
- Timeout error
- Auth error
- Rate-limit error
- Malformed response
Does not call real LLM vendors.
Does not require API keys.
"""
PONG_RESPONSE = "LANGBOT_FAKE_PONG"
def __init__(
self,
*,
default_response: str = "fake response",
streaming_chunks: list[str] = None,
raise_error: Exception = None,
captured_requests: list = None,
):
self._default_response = default_response
self._streaming_chunks = streaming_chunks or ["fake ", "response"]
self._raise_error = raise_error
self._captured_requests = captured_requests if captured_requests is not None else []
def returns(self, text: str) -> "FakeProvider":
"""Configure provider to return a specific text response."""
self._default_response = text
self._streaming_chunks = [text]
return self
def returns_streaming(self, chunks: list[str]) -> "FakeProvider":
"""Configure provider to return streaming chunks."""
self._streaming_chunks = chunks
self._default_response = "".join(chunks)
return self
def raises(self, error: Exception) -> "FakeProvider":
"""Configure provider to raise an error."""
self._raise_error = error
return self
def timeout(self) -> "FakeProvider":
"""Configure provider to simulate timeout."""
return self.raises(TimeoutError("Provider timeout"))
def auth_error(self) -> "FakeProvider":
"""Configure provider to simulate auth error."""
return self.raises(Exception("Invalid API key"))
def rate_limit(self) -> "FakeProvider":
"""Configure provider to simulate rate limit."""
return self.raises(Exception("Rate limit exceeded"))
def malformed(self) -> "FakeProvider":
"""Configure provider to simulate malformed response."""
self._default_response = None
return self
def get_captured_requests(self) -> list:
"""Get all captured request arguments for assertions."""
return self._captured_requests.copy()
def clear_captured_requests(self):
"""Clear captured requests."""
self._captured_requests.clear()
def _create_message(self, content: str) -> provider_message.Message:
"""Create a provider message from text content."""
return provider_message.Message(
role="assistant",
content=content,
)
def _create_chunk(
self,
content: str,
is_final: bool = False,
msg_sequence: int = 0,
) -> provider_message.MessageChunk:
"""Create a provider message chunk."""
return provider_message.MessageChunk(
role="assistant",
content=content,
is_final=is_final,
msg_sequence=msg_sequence,
)
async def invoke_llm(
self,
query,
model,
messages: list,
funcs: list,
extra_args: dict,
remove_think: bool = False,
) -> provider_message.Message:
"""Simulate non-streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
})
# Simulate error if configured
if self._raise_error:
raise self._raise_error
# Return response
if self._default_response is None:
# Malformed response
return provider_message.Message(role="assistant", content=None)
return self._create_message(self._default_response)
async def invoke_llm_stream(
self,
query,
model,
messages: list,
funcs: list,
extra_args: dict,
remove_think: bool = False,
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""Simulate streaming LLM invocation."""
# Capture request for assertions
self._captured_requests.append({
"query_id": query.query_id if query else None,
"model": model.model_entity.name if model and hasattr(model, 'model_entity') else None,
"messages": messages,
"funcs": funcs,
"extra_args": extra_args,
"streaming": True,
})
# Simulate error if configured
if self._raise_error:
raise self._raise_error
# Yield chunks
for i, chunk in enumerate(self._streaming_chunks):
is_final = (i == len(self._streaming_chunks) - 1)
yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i)
def fake_provider(
default_response: str = "fake response",
) -> FakeProvider:
"""Create a FakeProvider with optional default response."""
return FakeProvider(default_response=default_response)
def fake_provider_pong() -> FakeProvider:
"""Create a FakeProvider that returns the pong response."""
return FakeProvider(default_response=FakeProvider.PONG_RESPONSE)
def fake_provider_timeout() -> FakeProvider:
"""Create a FakeProvider that simulates timeout."""
return FakeProvider().timeout()
def fake_provider_auth_error() -> FakeProvider:
"""Create a FakeProvider that simulates auth error."""
return FakeProvider().auth_error()
def fake_provider_rate_limit() -> FakeProvider:
"""Create a FakeProvider that simulates rate limit."""
return FakeProvider().rate_limit()
def fake_provider_malformed() -> FakeProvider:
"""Create a FakeProvider that simulates malformed response."""
return FakeProvider().malformed()
# ============== Mock Model Factory ==============
def fake_model(
*,
uuid: str = "test-model-uuid",
name: str = "test-model",
abilities: list[str] = None,
provider: FakeProvider = None,
) -> Mock:
"""Create a mock model with a fake provider."""
model = Mock()
model.model_entity = Mock()
model.model_entity.uuid = uuid
model.model_entity.name = name
model.model_entity.abilities = abilities or ["func_call", "vision"]
model.model_entity.extra_args = {}
# Attach fake provider
if provider is None:
provider = FakeProvider()
model.provider = provider
return model

View File

@@ -1,6 +0,0 @@
"""
Integration tests package.
These tests validate real system behavior with actual database/network resources.
Run with: uv run pytest tests/integration/ -m "not slow" -q
"""

View File

@@ -1,5 +0,0 @@
"""
API integration tests package.
Tests for HTTP API endpoints using Quart test client.
"""

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
import pytest
def dedupe_preregistered_groups() -> None:
"""Keep API integration route registration isolated across test modules."""
from langbot.pkg.api.http.controller import group
seen: set[tuple[str, str]] = set()
unique_groups = []
for group_cls in group.preregistered_groups:
key = (group_cls.name, group_cls.path)
if key in seen:
continue
seen.add(key)
unique_groups.append(group_cls)
group.preregistered_groups[:] = unique_groups
@pytest.fixture(scope='module')
def http_controller_cls(mock_circular_import_chain):
"""Import HTTPController under each module's circular-import isolation."""
from langbot.pkg.api.http.controller.main import HTTPController
dedupe_preregistered_groups()
return HTTPController

View File

@@ -1,253 +0,0 @@
"""
API integration tests for bot endpoints.
Tests real HTTP API behavior for bot management.
Run: uv run pytest tests/integration/api/test_bots.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.platform',
'langbot.pkg.api.http.controller.groups.platform.bots',
'langbot.pkg.api.http.controller.groups.platform.adapters',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.platform.bots as _bots # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_bot_app():
"""Create FakeApp with bot services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[
{
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
}
])
app.bot_service.get_runtime_bot_info = AsyncMock(return_value={
'uuid': 'test-bot-uuid',
'name': 'Test Bot',
'platform': 'telegram',
'pipeline_uuid': 'test-pipeline-uuid',
'webhook_url': 'https://example.com/webhook/test-bot-uuid',
})
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
app.bot_service.update_bot = AsyncMock(return_value={})
app.bot_service.delete_bot = AsyncMock()
app.bot_service.list_event_logs = AsyncMock(return_value=(
[{'uuid': 'log-1', 'message': 'test log'}],
1
))
app.bot_service.send_message = AsyncMock()
# Platform manager
app.platform_mgr = Mock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_bot_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_bot_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotEndpoints:
"""Tests for /api/v1/platform/bots endpoints."""
@pytest.mark.asyncio
async def test_get_bots_success(self, quart_test_client):
"""GET /api/v1/platform/bots returns bot list."""
response = await quart_test_client.get(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert 'bots' in data['data']
@pytest.mark.asyncio
async def test_create_bot_success(self, quart_test_client):
"""POST /api/v1/platform/bots creates new bot."""
response = await quart_test_client.post(
'/api/v1/platform/bots',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Bot', 'platform': 'telegram', 'pipeline_uuid': 'test-pipeline'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_get_single_bot_success(self, quart_test_client):
"""GET /api/v1/platform/bots/{uuid} returns bot with runtime info."""
response = await quart_test_client.get(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'bot' in data['data']
@pytest.mark.asyncio
async def test_update_bot_success(self, quart_test_client):
"""PUT /api/v1/platform/bots/{uuid} updates bot."""
response = await quart_test_client.put(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Bot'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_bot_success(self, quart_test_client):
"""DELETE /api/v1/platform/bots/{uuid} deletes bot."""
response = await quart_test_client.delete(
'/api/v1/platform/bots/test-bot-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotLogsEndpoint:
"""Tests for bot logs endpoint."""
@pytest.mark.asyncio
async def test_get_bot_logs_success(self, quart_test_client):
"""POST /api/v1/platform/bots/{uuid}/logs returns logs."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/logs',
headers={'Authorization': 'Bearer test_token'},
json={'from_index': -1, 'max_count': 10}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'logs' in data['data']
assert 'total_count' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestBotSendMessageEndpoint:
"""Tests for bot send message endpoint."""
@pytest.mark.asyncio
async def test_send_message_success(self, quart_test_client):
"""POST /api/v1/platform/bots/{uuid}/send_message sends message."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={
'target_type': 'person',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert data['data']['sent'] is True
@pytest.mark.asyncio
async def test_send_message_missing_target_type(self, quart_test_client):
"""POST send_message without target_type returns 400."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={'target_id': 'user123', 'message_chain': [{'type': 'text', 'text': 'Hello'}]}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1
@pytest.mark.asyncio
async def test_send_message_invalid_target_type(self, quart_test_client):
"""POST send_message with invalid target_type returns 400."""
response = await quart_test_client.post(
'/api/v1/platform/bots/test-bot-uuid/send_message',
headers={'Authorization': 'Bearer test_api_key'},
json={
'target_type': 'invalid',
'target_id': 'user123',
'message_chain': [{'type': 'text', 'text': 'Hello'}]
}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1

View File

@@ -1,300 +0,0 @@
"""
API integration tests for embed widget endpoints.
Tests real HTTP API behavior for embed widget functionality.
Run: uv run pytest tests/integration/api/test_embed.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.embed',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.pipelines.embed as _embed # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_embed_app():
"""Create FakeApp with embed widget services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Create mock web_page_bot with valid UUID format
mock_bot_entity = Mock()
mock_bot_entity.uuid = 'a1b2c3d4-5678-90ab-cdef-123456789abc'
mock_bot_entity.adapter = 'web_page_bot'
mock_bot_entity.enable = True
mock_bot_entity.use_pipeline_uuid = 'test-pipeline-uuid'
mock_bot_entity.name = 'Test Web Bot'
mock_bot_entity.adapter_config = {
'turnstile_secret_key': '',
'turnstile_site_key': '',
'language': 'en_US',
'bubble_icon': 'logo',
}
mock_runtime_bot = Mock()
mock_runtime_bot.bot_entity = mock_bot_entity
# Platform manager with bots
app.platform_mgr = Mock()
app.platform_mgr.bots = [mock_runtime_bot]
# WebSocket proxy bot with adapter
mock_websocket_adapter = Mock()
mock_websocket_adapter.get_websocket_messages = Mock(return_value=[
{'id': 'msg-1', 'content': 'test message'}
])
mock_websocket_adapter.reset_session = Mock()
mock_websocket_adapter.handle_websocket_message = AsyncMock()
mock_ws_proxy_bot = Mock()
mock_ws_proxy_bot.adapter = mock_websocket_adapter
app.platform_mgr.websocket_proxy_bot = mock_ws_proxy_bot
# Monitoring service for feedback
app.monitoring_service = Mock()
app.monitoring_service.record_feedback = AsyncMock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_embed_app, http_controller_cls):
"""Create Quart test client (module scope)."""
controller = http_controller_cls(fake_embed_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedWidgetEndpoint:
"""Tests for widget.js endpoint."""
@pytest.mark.asyncio
async def test_get_widget_js_success(self, quart_test_client):
"""GET /api/v1/embed/{bot_uuid}/widget.js returns JS."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/widget.js'
)
assert response.status_code == 200
assert 'javascript' in response.content_type
@pytest.mark.asyncio
async def test_get_widget_js_invalid_uuid(self, quart_test_client):
"""GET widget.js with invalid UUID returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/invalid-uuid/widget.js'
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_get_widget_js_bot_not_found(self, quart_test_client):
"""GET widget.js for non-existent bot returns 404."""
response = await quart_test_client.get(
'/api/v1/embed/00000000-0000-0000-0000-000000000000/widget.js'
)
assert response.status_code == 404
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedLogoEndpoint:
"""Tests for logo endpoint."""
@pytest.mark.asyncio
async def test_get_logo_success(self, quart_test_client):
"""GET /api/v1/embed/logo returns image."""
response = await quart_test_client.get('/api/v1/embed/logo')
assert response.status_code == 200
assert 'image/webp' in response.content_type
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedTurnstileVerifyEndpoint:
"""Tests for Turnstile verification endpoint."""
@pytest.mark.asyncio
async def test_turnstile_verify_no_secret(self, quart_test_client):
"""POST turnstile verify without secret returns dummy token."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={'token': 'test-token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'token' in data['data']
@pytest.mark.asyncio
async def test_turnstile_verify_invalid_uuid(self, quart_test_client):
"""POST turnstile verify with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/turnstile/verify',
json={'token': 'test-token'}
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_turnstile_verify_missing_token(self, quart_test_client):
"""POST turnstile verify without token returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/turnstile/verify',
json={}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedMessagesEndpoint:
"""Tests for messages endpoint."""
@pytest.mark.asyncio
async def test_get_messages_person_success(self, quart_test_client):
"""GET messages/person returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'messages' in data['data']
@pytest.mark.asyncio
async def test_get_messages_group_success(self, quart_test_client):
"""GET messages/group returns messages."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/group',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_messages_invalid_session_type(self, quart_test_client):
"""GET messages with invalid session_type returns 400."""
response = await quart_test_client.get(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/messages/invalid',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedResetEndpoint:
"""Tests for session reset endpoint."""
@pytest.mark.asyncio
async def test_reset_session_person_success(self, quart_test_client):
"""POST reset/person resets session."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_reset_session_invalid_uuid(self, quart_test_client):
"""POST reset with invalid UUID returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/invalid-uuid/reset/person',
headers={'Authorization': 'Bearer 1234567890.dummy'}
)
assert response.status_code == 400
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbedFeedbackEndpoint:
"""Tests for feedback submission endpoint."""
@pytest.mark.asyncio
async def test_submit_feedback_like(self, quart_test_client):
"""POST feedback with type=1 (like) succeeds."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 1}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'feedback_id' in data['data']
@pytest.mark.asyncio
async def test_submit_feedback_dislike(self, quart_test_client):
"""POST feedback with type=2 (dislike) succeeds."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 2}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_submit_feedback_invalid_type(self, quart_test_client):
"""POST feedback with invalid type returns 400."""
response = await quart_test_client.post(
'/api/v1/embed/a1b2c3d4-5678-90ab-cdef-123456789abc/feedback',
headers={'Authorization': 'Bearer 1234567890.dummy'},
json={'message_id': 'msg-123', 'feedback_type': 99}
)
assert response.status_code == 400

View File

@@ -1,259 +0,0 @@
"""
API integration tests for knowledge base endpoints.
Tests real HTTP API behavior for knowledge base management.
Run: uv run pytest tests/integration/api/test_knowledge.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.knowledge',
'langbot.pkg.api.http.controller.groups.knowledge.base',
'langbot.pkg.api.http.controller.groups.knowledge.engines',
'langbot.pkg.api.http.controller.groups.knowledge.parsers',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.knowledge.base as _knowledge # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_knowledge_app():
"""Create FakeApp with knowledge services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Knowledge service
app.knowledge_service = Mock()
app.knowledge_service.get_knowledge_bases = AsyncMock(return_value=[
{
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
}
])
app.knowledge_service.get_knowledge_base = AsyncMock(return_value={
'uuid': 'test-kb-uuid',
'name': 'Test Knowledge Base',
'description': 'Test KB description',
'engine_plugin_id': 'test/engine',
})
app.knowledge_service.create_knowledge_base = AsyncMock(return_value={'uuid': 'new-kb-uuid'})
app.knowledge_service.update_knowledge_base = AsyncMock(return_value={})
app.knowledge_service.delete_knowledge_base = AsyncMock()
app.knowledge_service.get_files_by_knowledge_base = AsyncMock(return_value=[
{'uuid': 'test-file-uuid', 'filename': 'test.pdf'}
])
app.knowledge_service.store_file = AsyncMock(return_value={'task_id': 'test-task-id'})
app.knowledge_service.delete_file = AsyncMock()
app.knowledge_service.retrieve_knowledge_base = AsyncMock(return_value=[
{'content': 'test result', 'score': 0.95}
])
# RAG manager
app.rag_mgr = Mock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_knowledge_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_knowledge_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseEndpoints:
"""Tests for /api/v1/knowledge/bases endpoints."""
@pytest.mark.asyncio
async def test_get_knowledge_bases_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases returns knowledge base list."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert 'bases' in data['data']
@pytest.mark.asyncio
async def test_create_knowledge_base_success(self, quart_test_client):
"""POST /api/v1/knowledge/bases creates new knowledge base."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New KB', 'engine_plugin_id': 'test/engine'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_get_single_knowledge_base_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid} returns knowledge base."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'base' in data['data']
@pytest.mark.asyncio
async def test_update_knowledge_base_success(self, quart_test_client):
"""PUT /api/v1/knowledge/bases/{uuid} updates knowledge base."""
response = await quart_test_client.put(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated KB'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_knowledge_base_success(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid} deletes knowledge base."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseFilesEndpoints:
"""Tests for knowledge base files endpoints."""
@pytest.mark.asyncio
async def test_get_files_success(self, quart_test_client):
"""GET /api/v1/knowledge/bases/{uuid}/files returns files."""
response = await quart_test_client.get(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'files' in data['data']
@pytest.mark.asyncio
async def test_add_file_to_knowledge_base(self, quart_test_client):
"""POST /api/v1/knowledge/bases/{uuid}/files adds file."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/files',
headers={'Authorization': 'Bearer test_token'},
json={'file_id': 'test-file-id', 'parser_plugin_id': 'test/parser'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'task_id' in data['data']
@pytest.mark.asyncio
async def test_delete_file_from_knowledge_base(self, quart_test_client):
"""DELETE /api/v1/knowledge/bases/{uuid}/files/{file_id}."""
response = await quart_test_client.delete(
'/api/v1/knowledge/bases/test-kb-uuid/files/test-file-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestKnowledgeBaseRetrieveEndpoint:
"""Tests for knowledge base retrieval endpoint."""
@pytest.mark.asyncio
async def test_retrieve_knowledge_success(self, quart_test_client):
"""POST /api/v1/knowledge/bases/{uuid}/retrieve."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={'query': 'test query', 'retrieval_settings': {'top_k': 5}}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'results' in data['data']
@pytest.mark.asyncio
async def test_retrieve_without_query_returns_error(self, quart_test_client):
"""POST retrieve without query returns 400."""
response = await quart_test_client.post(
'/api/v1/knowledge/bases/test-kb-uuid/retrieve',
headers={'Authorization': 'Bearer test_token'},
json={}
)
assert response.status_code == 400
data = await response.get_json()
assert data['code'] == -1

View File

@@ -1,330 +0,0 @@
"""
API integration tests for monitoring endpoints.
Tests real HTTP API behavior for monitoring data retrieval.
Run: uv run pytest tests/integration/api/test_monitoring.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.monitoring',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.monitoring as _monitoring # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_monitoring_app():
"""Create FakeApp with monitoring services (module scope)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services - USER_TOKEN auth requires jwt verification AND get_user_by_email
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
# Monitoring service
app.monitoring_service = Mock()
app.monitoring_service.get_overview_metrics = AsyncMock(return_value={
'total_messages': 100,
'total_llm_calls': 50,
'total_sessions': 20,
'active_sessions': 5,
'total_errors': 2,
})
app.monitoring_service.get_messages = AsyncMock(return_value=(
[{'id': 'msg-1', 'content': 'test'}], 100
))
app.monitoring_service.get_llm_calls = AsyncMock(return_value=(
[{'id': 'llm-1'}], 50
))
app.monitoring_service.get_embedding_calls = AsyncMock(return_value=(
[{'id': 'emb-1'}], 10
))
app.monitoring_service.get_sessions = AsyncMock(return_value=(
[{'session_id': 'sess-1'}], 20
))
app.monitoring_service.get_errors = AsyncMock(return_value=(
[{'id': 'err-1'}], 2
))
app.monitoring_service.get_session_analysis = AsyncMock(return_value={
'found': True,
'session_id': 'sess-1',
})
app.monitoring_service.get_message_details = AsyncMock(return_value={
'found': True,
'message_id': 'msg-1',
})
app.monitoring_service.get_feedback_stats = AsyncMock(return_value={'like_count': 10})
app.monitoring_service.get_feedback_list = AsyncMock(return_value=(
[{'feedback_id': 'fb-1'}], 12
))
app.monitoring_service.export_messages = AsyncMock(return_value=[{'id': 'msg-1'}])
app.monitoring_service.export_llm_calls = AsyncMock(return_value=[{'id': 'llm-1'}])
app.monitoring_service.export_errors = AsyncMock(return_value=[{'id': 'err-1'}])
app.monitoring_service.export_sessions = AsyncMock(return_value=[{'session_id': 'sess-1'}])
app.monitoring_service.export_feedback = AsyncMock(return_value=[{'id': 'fb-1'}])
app.monitoring_service.export_embedding_calls = AsyncMock(return_value=[{'id': 'emb-1'}])
app.monitoring_service._escape_csv_field = Mock(return_value='escaped')
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_monitoring_app, http_controller_cls):
"""Create Quart test client (module scope)."""
controller = http_controller_cls(fake_monitoring_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringOverviewEndpoint:
"""Tests for /api/v1/monitoring/overview endpoint."""
@pytest.mark.asyncio
async def test_get_overview_success(self, quart_test_client):
"""GET /api/v1/monitoring/overview returns metrics."""
response = await quart_test_client.get(
'/api/v1/monitoring/overview',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringMessagesEndpoint:
"""Tests for /api/v1/monitoring/messages endpoint."""
@pytest.mark.asyncio
async def test_get_messages_success(self, quart_test_client):
"""GET /api/v1/monitoring/messages returns message list."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'messages' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringLLMCallsEndpoint:
"""Tests for /api/v1/monitoring/llm-calls endpoint."""
@pytest.mark.asyncio
async def test_get_llm_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/llm-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/llm-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringEmbeddingCallsEndpoint:
"""Tests for /api/v1/monitoring/embedding-calls endpoint."""
@pytest.mark.asyncio
async def test_get_embedding_calls_success(self, quart_test_client):
"""GET /api/v1/monitoring/embedding-calls."""
response = await quart_test_client.get(
'/api/v1/monitoring/embedding-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringSessionsEndpoint:
"""Tests for /api/v1/monitoring/sessions endpoint."""
@pytest.mark.asyncio
async def test_get_sessions_success(self, quart_test_client):
"""GET /api/v1/monitoring/sessions."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringErrorsEndpoint:
"""Tests for /api/v1/monitoring/errors endpoint."""
@pytest.mark.asyncio
async def test_get_errors_success(self, quart_test_client):
"""GET /api/v1/monitoring/errors."""
response = await quart_test_client.get(
'/api/v1/monitoring/errors',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringAllDataEndpoint:
"""Tests for /api/v1/monitoring/data endpoint."""
@pytest.mark.asyncio
async def test_get_all_data_success(self, quart_test_client):
"""GET /api/v1/monitoring/data returns all data."""
response = await quart_test_client.get(
'/api/v1/monitoring/data',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert 'overview' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringDetailsEndpoints:
"""Tests for detail endpoints."""
@pytest.mark.asyncio
async def test_get_session_analysis(self, quart_test_client):
"""GET /api/v1/monitoring/sessions/{id}/analysis."""
response = await quart_test_client.get(
'/api/v1/monitoring/sessions/sess-1/analysis',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_message_details(self, quart_test_client):
"""GET /api/v1/monitoring/messages/{id}/details."""
response = await quart_test_client.get(
'/api/v1/monitoring/messages/msg-1/details',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringFeedbackEndpoints:
"""Tests for feedback endpoints."""
@pytest.mark.asyncio
async def test_get_feedback_stats(self, quart_test_client):
"""GET /api/v1/monitoring/feedback/stats."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback/stats',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_feedback_list(self, quart_test_client):
"""GET /api/v1/monitoring/feedback."""
response = await quart_test_client.get(
'/api/v1/monitoring/feedback',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestMonitoringExportEndpoint:
"""Tests for /api/v1/monitoring/export endpoint."""
@pytest.mark.asyncio
async def test_export_messages(self, quart_test_client):
"""GET export?type=messages returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=messages',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
assert 'text/csv' in response.content_type
@pytest.mark.asyncio
async def test_export_llm_calls(self, quart_test_client):
"""GET export?type=llm-calls returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=llm-calls',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_export_sessions(self, quart_test_client):
"""GET export?type=sessions returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=sessions',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_export_feedback(self, quart_test_client):
"""GET export?type=feedback returns CSV."""
response = await quart_test_client.get(
'/api/v1/monitoring/export?type=feedback',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200

View File

@@ -1,273 +0,0 @@
"""
API integration tests for pipeline endpoints.
Tests real HTTP API behavior using Quart test client with mocked services.
Extends test_smoke.py coverage for pipeline-related endpoints.
Run: uv run pytest tests/integration/api/test_pipelines.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.pipelines',
'langbot.pkg.api.http.controller.groups.pipelines.embed',
'langbot.pkg.api.http.controller.groups.pipelines.websocket_chat',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
# Import groups after mocking to populate preregistered_groups
import langbot.pkg.api.http.controller.groups.pipelines.pipelines as _pipelines # noqa: E402, F401
yield
# ============== FAKE APPLICATION WITH PIPELINE SERVICES ==============
@pytest.fixture(scope='module')
def fake_pipeline_app():
"""Create FakeApp with pipeline-specific services (module scope for reuse)."""
app = FakeApp()
# Pipeline config
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Pipeline service
app.pipeline_service = Mock()
app.pipeline_service.get_pipeline_metadata = AsyncMock(return_value=[
{'name': 'trigger', 'stages': []},
{'name': 'ai', 'stages': []},
])
app.pipeline_service.get_pipelines = AsyncMock(return_value=[
{
'uuid': 'test-pipeline-uuid',
'name': 'Test Pipeline',
'description': 'Test description',
'created_at': '2024-01-01T00:00:00',
'updated_at': '2024-01-01T00:00:00',
'is_default': False,
}
])
app.pipeline_service.get_pipeline = AsyncMock(return_value={
'uuid': 'test-pipeline-uuid',
'name': 'Test Pipeline',
'config': {},
})
app.pipeline_service.create_pipeline = AsyncMock(return_value={'uuid': 'new-pipeline-uuid'})
app.pipeline_service.update_pipeline = AsyncMock(return_value={})
app.pipeline_service.delete_pipeline = AsyncMock()
app.pipeline_service.copy_pipeline = AsyncMock(return_value={'uuid': 'copied-pipeline-uuid'})
# Bot service
app.bot_service = Mock()
app.bot_service.get_bots = AsyncMock(return_value=[])
app.bot_service.create_bot = AsyncMock(return_value={'uuid': 'new-bot-uuid'})
# MCP service (for extensions endpoint)
app.mcp_service = Mock()
app.mcp_service.get_mcp_servers = AsyncMock(return_value=[])
# Plugin connector (for extensions endpoint)
app.plugin_connector.list_plugins = AsyncMock(return_value=[])
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_pipeline_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_pipeline_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
# ============== PIPELINE ENDPOINT TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineMetadataEndpoint:
"""Tests for /api/v1/pipelines/_/metadata endpoint."""
@pytest.mark.asyncio
async def test_get_pipeline_metadata_success(self, quart_test_client):
"""GET /api/v1/pipelines/_/metadata returns metadata list."""
response = await quart_test_client.get(
'/api/v1/pipelines/_/metadata',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
assert isinstance(data['data'], dict)
@pytest.mark.asyncio
async def test_get_pipeline_metadata_requires_auth(self, quart_test_client):
"""Pipeline metadata endpoint requires authentication."""
response = await quart_test_client.get('/api/v1/pipelines/_/metadata')
assert response.status_code == 401
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelinesListEndpoint:
"""Tests for /api/v1/pipelines endpoint."""
@pytest.mark.asyncio
async def test_get_pipelines_success(self, quart_test_client):
"""GET /api/v1/pipelines returns pipeline list."""
response = await quart_test_client.get(
'/api/v1/pipelines',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_get_pipelines_with_sort_param(self, quart_test_client):
"""GET pipelines with sort parameter."""
response = await quart_test_client.get(
'/api/v1/pipelines?sort_by=created_at&sort_order=DESC',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelinesCRUDEndpoints:
"""Tests for pipeline CRUD operations."""
@pytest.mark.asyncio
async def test_get_single_pipeline_success(self, quart_test_client):
"""GET /api/v1/pipelines/{uuid} returns pipeline."""
response = await quart_test_client.get(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_create_pipeline_success(self, quart_test_client):
"""POST /api/v1/pipelines creates new pipeline."""
response = await quart_test_client.post(
'/api/v1/pipelines',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Pipeline', 'config': {}}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_update_pipeline_success(self, quart_test_client):
"""PUT /api/v1/pipelines/{uuid} updates pipeline."""
response = await quart_test_client.put(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Pipeline'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_pipeline_success(self, quart_test_client):
"""DELETE /api/v1/pipelines/{uuid} deletes pipeline."""
response = await quart_test_client.delete(
'/api/v1/pipelines/test-pipeline-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_copy_pipeline_success(self, quart_test_client):
"""POST /api/v1/pipelines/{uuid}/copy copies pipeline."""
response = await quart_test_client.post(
'/api/v1/pipelines/test-pipeline-uuid/copy',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineExtensionsEndpoint:
"""Tests for pipeline extensions."""
@pytest.mark.asyncio
async def test_get_extensions(self, quart_test_client):
"""GET /api/v1/pipelines/{uuid}/extensions."""
response = await quart_test_client.get(
'/api/v1/pipelines/test-pipeline-uuid/extensions',
headers={'Authorization': 'Bearer test_token'}
)
# Should return 200 if pipeline found
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0

View File

@@ -1,347 +0,0 @@
"""
API integration tests for provider/model endpoints.
Tests real HTTP API behavior for provider and model management.
Run: uv run pytest tests/integration/api/test_providers.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""Break circular import chain for API controller."""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.provider',
'langbot.pkg.api.http.controller.groups.provider.providers',
'langbot.pkg.api.http.controller.groups.provider.models',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
import langbot.pkg.api.http.controller.groups.provider.providers as _providers # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.provider.models as _models # noqa: E402, F401
yield
@pytest.fixture(scope='module')
def fake_provider_app():
"""Create FakeApp with provider/model services (module scope for reuse)."""
app = FakeApp()
app.instance_config.data.update({
'api': {'port': 5300},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# Auth services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=True)
app.user_service.verify_jwt_token = AsyncMock(return_value='test@example.com')
app.user_service.get_user_by_email = AsyncMock(return_value=Mock(email='test@example.com'))
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
# Provider service
app.provider_service = Mock()
app.provider_service.get_providers = AsyncMock(return_value=[
{'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'}
])
app.provider_service.get_provider = AsyncMock(return_value={
'uuid': 'test-provider-uuid', 'name': 'OpenAI', 'requester': 'chatcmpl'
})
app.provider_service.create_provider = AsyncMock(return_value='new-provider-uuid')
app.provider_service.update_provider = AsyncMock(return_value={})
app.provider_service.delete_provider = AsyncMock()
app.provider_service.get_provider_model_counts = AsyncMock(return_value={
'llm_count': 2, 'embedding_count': 1, 'rerank_count': 0
})
# LLM model service
app.llm_model_service = Mock()
app.llm_model_service.get_llm_models = AsyncMock(return_value=[
{'uuid': 'test-model-uuid', 'name': 'gpt-4'}
])
app.llm_model_service.get_llm_model = AsyncMock(return_value={
'uuid': 'test-model-uuid', 'name': 'gpt-4'
})
app.llm_model_service.create_llm_model = AsyncMock(return_value={'uuid': 'new-model-uuid'})
app.llm_model_service.update_llm_model = AsyncMock(return_value={})
app.llm_model_service.delete_llm_model = AsyncMock()
# Embedding model service
app.embedding_models_service = Mock()
app.embedding_models_service.get_embedding_models = AsyncMock(return_value=[])
app.embedding_models_service.create_embedding_model = AsyncMock(return_value={'uuid': 'new-embedding-uuid'})
# Rerank model service
app.rerank_models_service = Mock()
app.rerank_models_service.get_rerank_models = AsyncMock(return_value=[])
app.rerank_models_service.create_rerank_model = AsyncMock(return_value={'uuid': 'new-rerank-uuid'})
# Model manager
app.model_mgr = Mock()
app.model_mgr.load_provider = AsyncMock()
app.model_mgr.unload_provider = AsyncMock()
return app
@pytest.fixture(scope='module')
async def quart_test_client(fake_provider_app, http_controller_cls):
"""Create Quart test client (module scope to avoid route re-registration)."""
controller = http_controller_cls(fake_provider_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProviderEndpoints:
"""Tests for /api/v1/provider endpoints."""
@pytest.mark.asyncio
async def test_get_providers_success(self, quart_test_client):
"""GET /api/v1/provider/providers returns provider list with complete structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
# Verify response structure completeness
providers = data['data']['providers']
assert isinstance(providers, list)
assert len(providers) == 1
# Verify required fields in provider object
provider = providers[0]
assert 'uuid' in provider
assert 'name' in provider
assert 'requester' in provider
assert provider['uuid'] == 'test-provider-uuid'
assert provider['name'] == 'OpenAI'
@pytest.mark.asyncio
async def test_get_single_provider_success(self, quart_test_client):
"""GET /api/v1/provider/providers/{uuid} returns complete provider structure."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Verify response structure
provider = data['data']['provider']
assert 'uuid' in provider
assert 'name' in provider
assert 'requester' in provider
assert provider['uuid'] == 'test-provider-uuid'
@pytest.mark.asyncio
async def test_create_provider_success(self, quart_test_client):
"""POST /api/v1/provider/providers creates new provider with uuid returned."""
response = await quart_test_client.post(
'/api/v1/provider/providers',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Provider', 'requester': 'chatcmpl'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Verify uuid is present and matches expected
assert 'data' in data
assert 'uuid' in data['data']
assert data['data']['uuid'] == 'new-provider-uuid'
@pytest.mark.asyncio
async def test_update_provider_success(self, quart_test_client):
"""PUT /api/v1/provider/providers/{uuid} updates provider."""
response = await quart_test_client.put(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'Updated Provider'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_delete_provider_success(self, quart_test_client):
"""DELETE /api/v1/provider/providers/{uuid} deletes provider."""
response = await quart_test_client.delete(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_get_provider_includes_model_counts(self, quart_test_client):
"""GET provider response includes model counts."""
response = await quart_test_client.get(
'/api/v1/provider/providers/test-provider-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
# Model counts are embedded in provider response
provider_data = data['data']['provider']
assert 'llm_count' in provider_data
assert 'embedding_count' in provider_data
assert 'rerank_count' in provider_data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestModelEndpoints:
"""Tests for /api/v1/provider/models endpoints."""
@pytest.mark.asyncio
async def test_get_llm_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'data' in data
@pytest.mark.asyncio
async def test_get_single_llm_model_success(self, quart_test_client):
"""GET /api/v1/provider/models/llm/{uuid} returns model."""
response = await quart_test_client.get(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
@pytest.mark.asyncio
async def test_create_llm_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/llm creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/llm',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.asyncio
async def test_delete_llm_model_success(self, quart_test_client):
"""DELETE /api/v1/provider/models/llm/{uuid} deletes model."""
response = await quart_test_client.delete(
'/api/v1/provider/models/llm/test-model-uuid',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestEmbeddingModelEndpoints:
"""Tests for /api/v1/provider/models/embedding endpoints."""
@pytest.mark.asyncio
async def test_get_embedding_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/embedding returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'models' in data['data']
@pytest.mark.asyncio
async def test_create_embedding_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/embedding creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/embedding',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Embedding Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRerankModelEndpoints:
"""Tests for /api/v1/provider/models/rerank endpoints."""
@pytest.mark.asyncio
async def test_get_rerank_models_success(self, quart_test_client):
"""GET /api/v1/provider/models/rerank returns model list."""
response = await quart_test_client.get(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'models' in data['data']
@pytest.mark.asyncio
async def test_create_rerank_model_success(self, quart_test_client):
"""POST /api/v1/provider/models/rerank creates new model."""
response = await quart_test_client.post(
'/api/v1/provider/models/rerank',
headers={'Authorization': 'Bearer test_token'},
json={'name': 'New Rerank Model', 'provider_uuid': 'test-provider-uuid'}
)
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert 'uuid' in data['data']

View File

@@ -1,345 +0,0 @@
"""
API smoke integration tests.
Tests real HTTP API behavior using Quart test client.
Validates controller/service/routing wiring without real provider/platform.
Run: uv run pytest tests/integration/api/test_smoke.py -q
"""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, Mock
from tests.factories import FakeApp
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain for API controller using isolated_sys_modules.
Chain: http_controller → groups/plugins → core.app → pipeline entities
We need to mock core.app to prevent the circular chain when importing HTTPController.
But we must allow groups to be imported to populate preregistered_groups.
"""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
# Mock core.app with minimal Application that groups can reference
class FakeMinimalApplication:
pass
mock_app = MagicMock()
mock_app.Application = FakeMinimalApplication
# Mock core.entities with proper Enum
mock_entities = MagicMock()
mock_entities.LifecycleControlScope = MockLifecycleControlScope
# Modules to clear (force re-import after mocking)
clear = [
'langbot.pkg.api.http.controller.group',
'langbot.pkg.api.http.controller.groups',
'langbot.pkg.api.http.controller.groups.system',
'langbot.pkg.api.http.controller.groups.user',
'langbot.pkg.api.http.controller.main',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.app': mock_app,
'langbot.pkg.core.entities': mock_entities,
},
clear=clear,
):
# Import groups after mocking core.app/core.entities
import langbot.pkg.api.http.controller.group as _group_module # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.system as _system_group # noqa: E402, F401
import langbot.pkg.api.http.controller.groups.user as _user_group # noqa: E402, F401
yield
# ============== FAKE APPLICATION FOR API TESTS ==============
@pytest.fixture
def fake_api_app():
"""
Create minimal FakeApp for API smoke tests with all required services.
Uses tests.factories.FakeApp as base and adds API-specific services.
"""
app = FakeApp()
# API-specific config
app.instance_config.data.update({
'api': {'port': 5300},
'plugin': {'enable_marketplace': True},
'space': {'url': 'https://space.langbot.app'},
'system': {'allow_modify_login_info': True, 'limitation': {}},
})
# API-specific services
app.user_service = Mock()
app.user_service.is_initialized = AsyncMock(return_value=False)
app.user_service.authenticate = AsyncMock(return_value='fake_token')
app.user_service.create_user = AsyncMock()
app.user_service.verify_jwt_token = AsyncMock(side_effect=ValueError('Invalid token'))
app.user_service.get_user_by_email = AsyncMock(return_value=Mock())
app.user_service.generate_jwt_token = AsyncMock(return_value='fake_token')
app.apikey_service = Mock()
app.apikey_service.verify_api_key = AsyncMock(return_value=True)
app.maintenance_service = Mock()
app.maintenance_service.get_storage_analysis = AsyncMock(return_value={})
app.plugin_connector.is_enable_plugin = False
app.plugin_connector.ping_plugin_runtime = AsyncMock()
app.task_mgr.get_tasks_dict = Mock(return_value={'tasks': []})
app.task_mgr.get_task_by_id = Mock(return_value=None)
# Required by controller groups
app.model_mgr = Mock()
app.platform_mgr = Mock()
app.pipeline_pool = Mock()
app.pipeline_mgr = Mock()
return app
# ============== QUART TEST CLIENT FIXTURE ==============
@pytest.fixture
async def quart_test_client(fake_api_app, http_controller_cls):
"""
Create Quart test client with real HTTPController and route registration.
Requires mock_circular_import_chain fixture to run first (usefixtures).
"""
controller = http_controller_cls(fake_api_app)
await controller.initialize()
client = controller.quart_app.test_client()
yield client
# ============== API SMOKE TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestHealthEndpoint:
"""Tests for /healthz endpoint - simplest smoke test."""
@pytest.mark.asyncio
async def test_healthz_returns_ok(self, quart_test_client):
"""
/healthz endpoint returns {'code': 0, 'msg': 'ok'}.
This tests:
- HTTPController instantiation
- Quart app creation
- Route registration
- Basic response handling
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
data = await response.get_json()
assert data == {'code': 0, 'msg': 'ok'}
@pytest.mark.asyncio
async def test_healthz_no_auth_required(self, quart_test_client):
"""
/healthz doesn't require authentication.
Tests that AuthType.NONE endpoints work without headers.
"""
response = await quart_test_client.get('/healthz')
assert response.status_code == 200
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestSystemEndpoint:
"""Tests for /api/v1/system endpoints."""
@pytest.mark.asyncio
async def test_system_info_no_auth(self, quart_test_client):
"""
/api/v1/system/info returns system information without auth.
AuthType.NONE endpoint.
"""
response = await quart_test_client.get('/api/v1/system/info')
assert response.status_code == 200
data = await response.get_json()
# Verify response structure
assert data['code'] == 0
assert data['msg'] == 'ok'
assert 'data' in data
# Verify expected fields
system_data = data['data']
assert 'version' in system_data
assert 'debug' in system_data
assert 'edition' in system_data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProtectedEndpoints:
"""Tests for authentication/authorization behavior."""
@pytest.mark.asyncio
async def test_protected_endpoint_rejects_no_token(self, quart_test_client):
"""
Protected endpoint (USER_TOKEN) returns 401 without auth.
Tests that AuthType.USER_TOKEN properly rejects unauthorized requests.
"""
# /api/v1/user/check-token requires USER_TOKEN
response = await quart_test_client.get('/api/v1/user/check-token')
assert response.status_code == 401
data = await response.get_json()
# Verify error response structure
assert data['code'] == -1
assert 'msg' in data
@pytest.mark.asyncio
async def test_protected_endpoint_with_invalid_token(self, quart_test_client):
"""
Protected endpoint returns 401 with invalid token.
"""
response = await quart_test_client.get(
'/api/v1/user/check-token',
headers={'Authorization': 'Bearer invalid_token'}
)
assert response.status_code == 401
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestInvalidPayload:
"""Tests for error handling with invalid payloads."""
@pytest.mark.asyncio
async def test_missing_json_body(self, quart_test_client):
"""
POST endpoint without JSON body handles gracefully.
"""
# /api/v1/user/auth expects JSON with 'user' and 'password'
response = await quart_test_client.post('/api/v1/user/auth')
# Should return error (500, 400, or 401) with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
# Verify error response has expected structure
assert 'code' in data
assert 'msg' in data
@pytest.mark.asyncio
async def test_invalid_json_structure(self, quart_test_client):
"""
POST with wrong JSON structure returns stable error.
"""
response = await quart_test_client.post(
'/api/v1/user/auth',
json={'wrong_field': 'value'}
)
# Should return error with stable JSON structure
assert response.status_code in (400, 500, 401)
data = await response.get_json()
assert 'code' in data
assert 'msg' in data
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestUserInitEndpoint:
"""Tests for /api/v1/user/init endpoint."""
@pytest.mark.asyncio
async def test_user_init_get_returns_not_initialized(self, quart_test_client):
"""
GET /api/v1/user/init returns initialized status.
Uses fake user_service.is_initialized() = False.
"""
response = await quart_test_client.get('/api/v1/user/init')
assert response.status_code == 200
data = await response.get_json()
assert data['code'] == 0
assert data['msg'] == 'ok'
assert data['data']['initialized'] is False
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRealImports:
"""Tests that verify real production code is imported."""
def test_http_controller_real_import(self):
"""
Verify HTTPController is real production class, not mock.
"""
from langbot.pkg.api.http.controller.main import HTTPController
assert HTTPController.__name__ == 'HTTPController'
assert hasattr(HTTPController, 'initialize')
assert hasattr(HTTPController, 'register_routes')
def test_group_real_import(self):
"""
Verify RouterGroup and AuthType are real production classes.
"""
from langbot.pkg.api.http.controller.group import RouterGroup, AuthType, preregistered_groups
assert RouterGroup.__name__ == 'RouterGroup'
assert hasattr(AuthType, 'NONE')
assert hasattr(AuthType, 'USER_TOKEN')
assert isinstance(preregistered_groups, list)
def test_system_group_registered(self):
"""
Verify SystemRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find system group
system_group = None
for g in preregistered_groups:
if g.name == 'system':
system_group = g
break
assert system_group is not None
assert system_group.path == '/api/v1/system'
def test_user_group_registered(self):
"""
Verify UserRouterGroup is registered in preregistered_groups.
"""
from langbot.pkg.api.http.controller.group import preregistered_groups
# Find user group
user_group = None
for g in preregistered_groups:
if g.name == 'user':
user_group = g
break
assert user_group is not None
assert user_group.path == '/api/v1/user'

View File

@@ -1,5 +0,0 @@
"""
Persistence integration tests package.
Tests for database migrations and storage behavior.
"""

View File

@@ -1,251 +0,0 @@
"""
SQLite migration integration tests.
Tests real Alembic migration behavior using temporary SQLite databases.
Validates the migration workflow from .github/workflows/test-migrations.yml.
Run: uv run pytest tests/integration/persistence/test_migrations.py -q
"""
from __future__ import annotations
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import (
run_alembic_upgrade,
run_alembic_stamp,
get_alembic_current,
)
pytestmark = pytest.mark.integration
@pytest.fixture
def sqlite_db_url(tmp_path):
"""Create SQLite URL with temporary database file."""
db_file = tmp_path / "test_migrations.db"
return f"sqlite+aiosqlite:///{db_file}"
@pytest.fixture
async def sqlite_engine(sqlite_db_url):
"""Create async SQLite engine."""
engine = create_async_engine(sqlite_db_url)
yield engine
await engine.dispose()
class TestSQLiteMigrationBaseline:
"""Tests for baseline stamp workflow."""
@pytest.mark.asyncio
async def test_baseline_stamp_sets_revision(self, sqlite_engine):
"""
Stamp baseline on existing tables sets correct revision.
Workflow:
1. Create tables via Base.metadata.create_all
2. Stamp with '0001_baseline'
3. Verify current revision is '0001_baseline'
"""
# Create all tables (simulates existing DB created by ORM)
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_baseline_stamp_on_empty_db(self, sqlite_engine):
"""
Stamp on empty database (no tables) still sets revision.
This is an edge case - stamping without tables.
"""
# Don't create tables - stamp directly
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'
class TestSQLiteMigrationUpgrade:
"""Tests for upgrade to head workflow."""
@pytest.mark.asyncio
async def test_upgrade_from_baseline_to_head(self, sqlite_engine):
"""
Upgrade from baseline to head applies all migrations.
Workflow:
1. Create tables
2. Stamp baseline
3. Upgrade to head
4. Verify current revision is head
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(sqlite_engine, '0001_baseline')
# Upgrade to head
await run_alembic_upgrade(sqlite_engine, 'head')
# Verify revision
rev = await get_alembic_current(sqlite_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
@pytest.mark.asyncio
async def test_upgrade_idempotent(self, sqlite_engine):
"""
Running upgrade to head multiple times is idempotent.
Workflow:
1. Upgrade to head
2. Get revision
3. Upgrade to head again
4. Verify same revision
"""
# Create tables
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp and upgrade
await run_alembic_stamp(sqlite_engine, '0001_baseline')
await run_alembic_upgrade(sqlite_engine, 'head')
rev1 = await get_alembic_current(sqlite_engine)
# Upgrade again - should be idempotent
await run_alembic_upgrade(sqlite_engine, 'head')
rev2 = await get_alembic_current(sqlite_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestSQLiteMigrationFreshDatabase:
"""Tests for fresh database workflow."""
@pytest.mark.asyncio
async def test_fresh_db_upgrade_from_scratch(self, tmp_path):
"""
Fresh database (no tables) can be upgraded directly to head.
Workflow:
1. Create fresh engine with new DB file
2. Create tables
3. Upgrade to head
4. Verify revision
"""
# Use different DB file for fresh test
fresh_db_file = tmp_path / "test_migrations_fresh.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Create tables on fresh DB
async with fresh_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Upgrade to head directly (no baseline stamp)
await run_alembic_upgrade(fresh_engine, 'head')
# Verify revision
rev = await get_alembic_current(fresh_engine)
assert rev is not None, "Expected a revision on fresh DB"
await fresh_engine.dispose()
@pytest.mark.asyncio
async def test_fresh_db_without_create_all_behavior(self, tmp_path):
"""
Fresh database without create_all - test actual behavior.
This tests what happens when migrations run on truly empty DB.
The behavior is determined by Alembic and migration scripts.
EXPECTED: Either:
1. Migration succeeds (if scripts handle empty DB)
2. Migration fails with specific error (if scripts require tables)
IMPORTANT: This test verifies the ACTUAL behavior, not accepting
any arbitrary failure with try-except pass.
"""
fresh_db_file = tmp_path / "test_empty_migrations.db"
fresh_url = f"sqlite+aiosqlite:///{fresh_db_file}"
fresh_engine = create_async_engine(fresh_url)
# Capture the actual behavior
actual_result = None
actual_error = None
try:
await run_alembic_upgrade(fresh_engine, 'head')
rev = await get_alembic_current(fresh_engine)
actual_result = rev
except Exception as e:
actual_error = e
await fresh_engine.dispose()
# Verify specific behavior - one of two outcomes is expected
if actual_result is not None:
# Migration succeeded - verify revision exists
assert actual_result is not None, "Revision should exist after successful migration"
else:
# Migration failed - verify the error type is known
# Alembic typically raises specific errors for missing tables
assert actual_error is not None, "Error should be captured if migration failed"
# Log the error type for documentation (don't silently pass)
error_type = type(actual_error).__name__
# Acceptable error types for empty DB scenarios
acceptable_errors = [
'OperationalError', # SQLite table not found
'ProgrammingError', # SQLAlchemy errors
'CommandError', # Alembic command errors
]
assert error_type in acceptable_errors, (
f"Unexpected error type: {error_type}. "
f"This may indicate a regression in migration behavior. "
f"Error: {actual_error}"
)
class TestSQLiteMigrationGetCurrent:
"""Tests for get_alembic_current behavior."""
@pytest.mark.asyncio
async def test_get_current_on_unstamped_db_returns_none(self, sqlite_engine):
"""
get_alembic_current returns None for unstamped database.
"""
# Create tables but don't stamp
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# No stamp - should return None
rev = await get_alembic_current(sqlite_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio
async def test_get_current_after_stamp_returns_revision(self, sqlite_engine):
"""
get_alembic_current returns correct revision after stamp.
"""
async with sqlite_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_stamp(sqlite_engine, '0001_baseline')
rev = await get_alembic_current(sqlite_engine)
assert rev == '0001_baseline'

View File

@@ -1,217 +0,0 @@
"""
PostgreSQL migration integration tests.
Tests real Alembic migration behavior using PostgreSQL database.
Marked as slow - requires external PostgreSQL service.
Run locally (requires PostgreSQL):
TEST_POSTGRES_URL=postgresql+asyncpg://user:pass@localhost:5432/test_db \
uv run pytest tests/integration/persistence/test_migrations_postgres.py -q
CI runs automatically with PostgreSQL service container.
"""
from __future__ import annotations
import os
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text
from langbot.pkg.entity.persistence.base import Base
from langbot.pkg.persistence.alembic_runner import (
run_alembic_upgrade,
run_alembic_stamp,
get_alembic_current,
)
pytestmark = [pytest.mark.integration, pytest.mark.slow]
@pytest.fixture
def postgres_url():
"""Get PostgreSQL URL from environment."""
url = os.environ.get('TEST_POSTGRES_URL')
if not url:
pytest.skip("TEST_POSTGRES_URL not set")
return url
@pytest.fixture
async def postgres_engine(postgres_url):
"""Create async PostgreSQL engine."""
engine = create_async_engine(postgres_url, isolation_level="AUTOCOMMIT")
yield engine
await engine.dispose()
@pytest.fixture
async def clean_tables(postgres_engine):
"""Drop all tables before and after each test for isolation."""
# Drop all tables before test
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
yield
# Drop all tables after test
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def clean_alembic_version(postgres_engine):
"""Drop alembic_version table before and after each test."""
async with postgres_engine.begin() as conn:
# Drop alembic_version table if exists
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception:
pass
yield
async with postgres_engine.begin() as conn:
try:
await conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
except Exception:
pass
class TestPostgreSQLMigrationBaseline:
"""Tests for baseline stamp workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_sets_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Stamp baseline on existing tables sets correct revision.
Workflow:
1. Create tables via Base.metadata.create_all
2. Stamp with '0001_baseline'
3. Verify current revision is '0001_baseline'
"""
# Create all tables (simulates existing DB created by ORM)
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(postgres_engine, '0001_baseline')
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline', f"Expected '0001_baseline', got {rev}"
@pytest.mark.asyncio
async def test_postgres_baseline_stamp_on_empty_db(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Stamp on empty database (no tables) still sets revision.
This is an edge case - stamping without tables.
"""
# Don't create tables - stamp directly
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'
class TestPostgreSQLMigrationUpgrade:
"""Tests for upgrade to head workflow on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_upgrade_from_baseline_to_head(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Upgrade from baseline to head applies all migrations.
Workflow:
1. Create tables
2. Stamp baseline
3. Upgrade to head
4. Verify current revision is head
"""
# Create tables
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp baseline
await run_alembic_stamp(postgres_engine, '0001_baseline')
# Upgrade to head
await run_alembic_upgrade(postgres_engine, 'head')
# Verify revision
rev = await get_alembic_current(postgres_engine)
assert rev is not None, "Expected a revision after upgrade"
# Head should be the latest migration (0003 for current state)
assert rev.startswith('0003'), f"Expected head to be 0003_*, got {rev}"
@pytest.mark.asyncio
async def test_postgres_upgrade_idempotent(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
Running upgrade to head multiple times is idempotent.
Workflow:
1. Upgrade to head
2. Get revision
3. Upgrade to head again
4. Verify same revision
"""
# Create tables
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Stamp and upgrade
await run_alembic_stamp(postgres_engine, '0001_baseline')
await run_alembic_upgrade(postgres_engine, 'head')
rev1 = await get_alembic_current(postgres_engine)
# Upgrade again - should be idempotent
await run_alembic_upgrade(postgres_engine, 'head')
rev2 = await get_alembic_current(postgres_engine)
assert rev2 == rev1, f"Expected {rev1}, got {rev2}"
class TestPostgreSQLMigrationGetCurrent:
"""Tests for get_alembic_current behavior on PostgreSQL."""
@pytest.mark.asyncio
async def test_postgres_get_current_on_unstamped_db_returns_none(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
get_alembic_current returns None for unstamped database.
"""
# Create tables but don't stamp
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# No stamp - should return None
rev = await get_alembic_current(postgres_engine)
assert rev is None, f"Expected None for unstamped DB, got {rev}"
@pytest.mark.asyncio
async def test_postgres_get_current_after_stamp_returns_revision(
self, postgres_engine, clean_tables, clean_alembic_version
):
"""
get_alembic_current returns correct revision after stamp.
"""
async with postgres_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await run_alembic_stamp(postgres_engine, '0001_baseline')
rev = await get_alembic_current(postgres_engine)
assert rev == '0001_baseline'

View File

@@ -1,5 +0,0 @@
"""
Pipeline integration tests package.
Tests for full pipeline flow using fake provider/runner.
"""

View File

@@ -1,778 +0,0 @@
"""
Pipeline full-flow integration tests.
Tests real pipeline stages with fake runner/provider.
Validates message processing through PreProcessor, Processor, and SendResponseBackStage.
Uses RuntimePipeline directly (not PipelineManager) to avoid DB dependency.
Run: uv run pytest tests/integration/pipeline -q --tb=short
"""
from __future__ import annotations
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock
import sys
from tests.factories import FakeApp, text_query, mock_platform_adapter
from tests.factories.provider import FakeProvider
from tests.factories.platform import FakePlatform
pytestmark = pytest.mark.integration
# ============== FIXTURE FOR SYS.MODULES ISOLATION ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
Break circular import chain for pipeline modules using isolated_sys_modules.
Chain: pipeline → core.app → provider.runner → http_controller → groups/plugins
We mock minimal modules to allow importing RuntimePipeline, StageInstContainer,
and stage classes without triggering full application initialization.
After mocking, we import the stage modules so decorators register them.
"""
from tests.utils.import_isolation import isolated_sys_modules, MockLifecycleControlScope
# Mock core.entities with LifecycleControlScope enum
mock_core_entities = Mock()
mock_core_entities.LifecycleControlScope = MockLifecycleControlScope
# Mock core.app - Application class is referenced but not instantiated
mock_core_app = Mock()
# Mock provider.runner with preregistered_runners list
mock_runner = Mock()
mock_runner.preregistered_runners = [] # Will be populated in tests
# Mock utils.importutil - prevents auto-import of runners
mock_importutil = Mock()
mock_importutil.import_modules_in_pkg = lambda pkg: None
mock_importutil.import_modules_in_pkgs = lambda pkgs: None
# Modules to clear (force re-import after mocking)
clear = [
'langbot.pkg.pipeline.stage',
'langbot.pkg.pipeline.entities',
'langbot.pkg.pipeline.pipelinemgr',
'langbot.pkg.pipeline.preproc.preproc',
'langbot.pkg.pipeline.process.process',
'langbot.pkg.pipeline.process.handler',
'langbot.pkg.pipeline.process.handlers.chat',
'langbot.pkg.pipeline.process.handlers.command',
'langbot.pkg.pipeline.respback.respback',
'langbot.pkg.provider.runner',
]
with isolated_sys_modules(
mocks={
'langbot.pkg.core.entities': mock_core_entities,
'langbot.pkg.core.app': mock_core_app,
'langbot.pkg.provider.runner': mock_runner,
'langbot.pkg.utils.importutil': mock_importutil,
'langbot.pkg.pipeline.controller': Mock(),
'langbot.pkg.pipeline.pipelinemgr': Mock(),
},
clear=clear,
):
# Import stage modules AFTER clearing so decorators register them
from importlib import import_module
# Import stage base first
import_module('langbot.pkg.pipeline.stage')
# Import entities
import_module('langbot.pkg.pipeline.entities')
# Import specific stages to register them
import_module('langbot.pkg.pipeline.preproc.preproc')
import_module('langbot.pkg.pipeline.process.process')
import_module('langbot.pkg.pipeline.respback.respback')
# Import pipelinemgr for RuntimePipeline
import_module('langbot.pkg.pipeline.pipelinemgr')
yield
# ============== FAKE RUNNER ==============
class FakeRunner:
"""Minimal fake runner class for pipeline integration tests.
Note: preregistered_runners expects a CLASS, not an instance.
The handler calls runner_cls(self.ap, query.pipeline_config) to instantiate.
"""
name = 'local-agent'
def __init__(self, app=None, config=None):
self.app = app
self.config = config or {}
self._provider = FakeProvider()
# Instance-level configuration set via class attribute
self._response_text = "fake response"
self._raise_error = None
@classmethod
def returns(cls, text: str):
"""Create a runner class configured to return specific text."""
# We create a subclass with configured response
class ConfiguredRunner(cls):
name = cls.name
_response_text = text
_raise_error = None
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._response_text = text
return ConfiguredRunner
@classmethod
def raises(cls, error: Exception):
"""Create a runner class configured to raise an error."""
class ConfiguredRunner(cls):
name = cls.name
_response_text = None
_raise_error = error
def __init__(self, app=None, config=None):
super().__init__(app, config)
self._raise_error = error
return ConfiguredRunner
async def run(self, query):
"""Run the fake provider and yield messages."""
from langbot_plugin.api.entities.builtin.provider.message import Message
# Use the configured response/error
if self._raise_error:
raise self._raise_error
# Yield a simple message
yield Message(role='assistant', content=self._response_text)
# ============== PIPELINE APP FIXTURE ==============
@pytest.fixture
def pipeline_app():
"""
Create FakeApp with all dependencies required by pipeline stages.
PreProcessor needs: sess_mgr, model_mgr, tool_mgr, plugin_connector
Processor needs: instance_config, plugin_connector
SendResponseBackStage needs: logger
ChatMessageHandler needs: telemetry, survey
"""
app = FakeApp()
# Session/conversation mocks for PreProcessor
mock_session = Mock()
mock_session.launcher_type = Mock()
mock_session.launcher_type.value = 'person'
mock_session.launcher_id = 12345
mock_session.sender_id = 12345
mock_session.use_prompt_name = 'default'
mock_session.using_conversation = None
# Create a simple class to mimic Prompt behavior
class MockPrompt:
def __init__(self, name, messages):
self.name = name
self.messages = messages
def copy(self):
return MockPrompt(self.name, list(self.messages))
# Create real lists for messages
prompt_messages_list = []
messages_list = []
mock_prompt = MockPrompt('default', prompt_messages_list)
mock_conversation = Mock()
mock_conversation.prompt = mock_prompt
mock_conversation.messages = messages_list
mock_conversation.uuid = 'test-conversation-uuid'
mock_conversation.update_time = None
mock_conversation.create_time = None
app.sess_mgr.get_session = AsyncMock(return_value=mock_session)
app.sess_mgr.get_conversation = AsyncMock(return_value=mock_conversation)
# Model mock for PreProcessor
mock_model = Mock()
mock_model.model_entity = Mock()
mock_model.model_entity.uuid = 'test-model-uuid'
mock_model.model_entity.name = 'test-model'
mock_model.model_entity.abilities = ['func_call', 'vision']
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=mock_model)
# Tool manager mock
app.tool_mgr.get_all_tools = AsyncMock(return_value=[])
# Telemetry mock (required by ChatMessageHandler)
app.telemetry = Mock()
app.telemetry.start_send_task = AsyncMock()
# Survey mock
app.survey = None
return app
@pytest.fixture
def fake_platform_adapter():
"""Create a fake platform adapter for outbound capture."""
platform = FakePlatform(stream_output_supported=False)
adapter = mock_platform_adapter(platform)
return adapter, platform
@pytest.fixture
def set_fake_runner():
"""Factory fixture to set a fake runner CLASS in preregistered_runners."""
def _set_runner(runner_cls):
# preregistered_runners expects a list of runner classes
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_cls]
return _set_runner
# ============== PIPELINE CONFIGURATION ==============
def create_minimal_pipeline_config():
"""Create minimal pipeline configuration for tests."""
return {
'ai': {
'runner': {'runner': 'local-agent', 'expire-time': None},
'local-agent': {
'model': {'primary': 'test-model-uuid', 'fallbacks': []},
'prompt': 'default',
'knowledge-bases': [],
},
},
'output': {
'force-delay': {'min': 0.0, 'max': 0.0},
'misc': {
'at-sender': False,
'quote-origin': False,
'exception-handling': 'show-hint',
'failure-hint': 'Request failed.',
},
},
'trigger': {
'misc': {'combine-quote-message': False},
},
}
# ============== HELPER TO PROCESS COROUTINE/GENERATOR ==============
async def collect_processor_results(processor, query, stage_name):
"""
Helper to handle the coroutine -> async_generator pattern.
Processor.process() returns a coroutine that yields an async_generator.
This helper handles both cases like RuntimePipeline does.
"""
result = processor.process(query, stage_name)
# Handle coroutine (await it to get async_generator)
if asyncio.iscoroutine(result):
result = await result
# Now iterate over async_generator
results = []
async for item in result:
results.append(item)
return results
# ============== TESTS ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPipelineStageChainReal:
"""Tests for real pipeline stage chain."""
@pytest.mark.asyncio
async def test_import_pipeline_modules(self):
"""Verify we can import real pipeline modules."""
from langbot.pkg.pipeline import stage, entities
from langbot.pkg.pipeline import pipelinemgr
assert hasattr(stage, 'PipelineStage')
assert hasattr(stage, 'preregistered_stages')
assert hasattr(entities, 'ResultType')
assert hasattr(entities, 'StageProcessResult')
assert hasattr(pipelinemgr, 'RuntimePipeline')
assert hasattr(pipelinemgr, 'StageInstContainer')
@pytest.mark.asyncio
async def test_stage_preregistration(self):
"""Verify stages are preregistered after fixture imports them."""
from langbot.pkg.pipeline import stage
# Check that our target stages are registered
assert 'PreProcessor' in stage.preregistered_stages
assert 'MessageProcessor' in stage.preregistered_stages
assert 'SendResponseBackStage' in stage.preregistered_stages
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestPreProcessorStage:
"""Tests for PreProcessor stage alone."""
@pytest.mark.asyncio
async def test_preproc_continues_on_valid_query(self, pipeline_app, fake_platform_adapter):
"""PreProcessor should return CONTINUE for valid text query."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
adapter, platform = fake_platform_adapter
# Create query with adapter
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector for PromptPreProcessing event
mock_event_ctx = Mock()
mock_event_ctx.event = Mock()
mock_event_ctx.event.default_prompt = [] # Real list
mock_event_ctx.event.prompt = [] # Real list
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create PreProcessor stage
preproc_stage = preproc.PreProcessor(pipeline_app)
result = await preproc_stage.process(query, 'PreProcessor')
assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query.session is not None
assert result.new_query.user_message is not None
@pytest.mark.asyncio
async def test_preproc_sets_user_message(self, pipeline_app, fake_platform_adapter):
"""PreProcessor should set user_message from message_chain."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
adapter, platform = fake_platform_adapter
query = text_query("test message content")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector for PromptPreProcessing event
mock_event_ctx = Mock()
mock_event_ctx.event = Mock()
mock_event_ctx.event.default_prompt = []
mock_event_ctx.event.prompt = []
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
preproc_stage = preproc.PreProcessor(pipeline_app)
result = await preproc_stage.process(query, 'PreProcessor')
assert result.result_type == entities.ResultType.CONTINUE
# Check user_message content
assert result.new_query.user_message is not None
assert result.new_query.user_message.role == 'user'
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestProcessorStage:
"""Tests for MessageProcessor stage."""
@pytest.mark.asyncio
async def test_processor_calls_chat_handler(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""Processor should route to ChatMessageHandler for non-command messages."""
adapter, platform = fake_platform_adapter
# Set fake runner that returns pong
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner)
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
# Collect results using helper
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) >= 1
# Check that resp_messages was populated
assert len(query.resp_messages) >= 1
@pytest.mark.asyncio
async def test_processor_prevent_default_without_reply_interrupts(self, pipeline_app, fake_platform_adapter):
"""Processor should INTERRUPT when plugin prevents default without reply."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector to prevent default without reply
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=True)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
@pytest.mark.asyncio
async def test_processor_prevent_default_with_reply_continues(self, pipeline_app, fake_platform_adapter):
"""Processor should CONTINUE when plugin prevents default with reply."""
from langbot.pkg.pipeline import entities
from tests.factories.message import text_chain
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = []
# Create reply chain
reply_chain = text_chain("plugin response")
# Mock plugin_connector to prevent default with reply
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=True)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = reply_chain
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.CONTINUE
assert len(query.resp_messages) == 1
assert query.resp_messages[0] == reply_chain
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestRunnerExceptionFlow:
"""Tests for runner exception handling."""
@pytest.mark.asyncio
async def test_runner_exception_yields_interrupt(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""Runner exception should yield INTERRUPT with error notices."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(ValueError("API Error: rate limit exceeded"))
set_fake_runner(fake_runner)
# Create query with exception handling config
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-hint'
config['output']['misc']['failure-hint'] = 'Request failed.'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert results[0].user_notice == 'Request failed.'
assert results[0].error_notice is not None
@pytest.mark.asyncio
async def test_runner_exception_show_error_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""show-error mode should show actual exception message."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises specific exception
fake_runner = FakeRunner().raises(RuntimeError("Custom runtime error"))
set_fake_runner(fake_runner)
# Create query with show-error mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'show-error'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert 'Custom runtime error' in results[0].user_notice
@pytest.mark.asyncio
async def test_runner_exception_hide_mode(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""hide mode should not show user notice."""
from langbot.pkg.pipeline import entities
adapter, platform = fake_platform_adapter
# Set fake runner that raises exception
fake_runner = FakeRunner().raises(Exception("Hidden error"))
set_fake_runner(fake_runner)
# Create query with hide mode
config = create_minimal_pipeline_config()
config['output']['misc']['exception-handling'] = 'hide'
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = config
# Mock plugin_connector to not prevent default
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
# Create Processor stage
from langbot.pkg.pipeline.process import process
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
assert results[0].user_notice is None
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestSendResponseBackStage:
"""Tests for SendResponseBackStage."""
@pytest.mark.asyncio
async def test_send_response_calls_adapter(self, pipeline_app, fake_platform_adapter):
"""SendResponseBackStage should call adapter.reply_message."""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.respback import respback
from tests.factories.message import text_chain
from langbot_plugin.api.entities.builtin.provider.message import Message
adapter, platform = fake_platform_adapter
# Create query with response message
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Add response message
query.resp_messages = [Message(role='assistant', content='test response')]
query.resp_message_chain = [text_chain('test response')]
# Create SendResponseBackStage
respback_stage = respback.SendResponseBackStage(pipeline_app)
result = await respback_stage.process(query, 'SendResponseBackStage')
assert result.result_type == entities.ResultType.CONTINUE
# Check that adapter was called
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]['type'] == 'reply'
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestStageChainIntegration:
"""Tests for full stage chain (PreProcessor -> Processor -> SendResponseBackStage)."""
@pytest.mark.asyncio
async def test_full_chain_text_message_flow(self, pipeline_app, fake_platform_adapter, set_fake_runner):
"""
Full chain: text message -> PreProcessor -> Processor -> SendResponseBackStage.
Validates:
- PreProcessor sets up session, user_message
- Processor calls runner and populates resp_messages
- SendResponseBackStage calls adapter.reply_message
"""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
from langbot.pkg.pipeline.process import process
from langbot.pkg.pipeline.respback import respback
adapter, platform = fake_platform_adapter
# Set fake runner
fake_runner = FakeRunner().returns("LANGBOT_FAKE_PONG")
set_fake_runner(fake_runner)
# Create query
config = create_minimal_pipeline_config()
query = text_query("ping")
query.adapter = adapter
query.pipeline_config = config
query.resp_messages = []
query.resp_message_chain = []
# Mock plugin_connector for PreProcessor and Processor events
mock_event_ctx_preproc = Mock()
mock_event_ctx_preproc.event = Mock()
mock_event_ctx_preproc.event.default_prompt = []
mock_event_ctx_preproc.event.prompt = []
mock_event_ctx_processor = Mock()
mock_event_ctx_processor.is_prevented_default = Mock(return_value=False)
mock_event_ctx_processor.event = Mock()
mock_event_ctx_processor.event.user_message_alter = None
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
# Create stages
preproc_stage = preproc.PreProcessor(pipeline_app)
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(config)
respback_stage = respback.SendResponseBackStage(pipeline_app)
# Run PreProcessor
result1 = await preproc_stage.process(query, 'PreProcessor')
assert result1.result_type == entities.ResultType.CONTINUE
query = result1.new_query
# Run Processor
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) >= 1
# Build resp_message_chain from resp_messages
from tests.factories.message import text_chain
for resp_msg in query.resp_messages:
if resp_msg.content:
query.resp_message_chain.append(text_chain(resp_msg.content))
# Run SendResponseBackStage
result3 = await respback_stage.process(query, 'SendResponseBackStage')
assert result3.result_type == entities.ResultType.CONTINUE
# Verify adapter was called
outbound = platform.get_outbound_messages()
assert len(outbound) >= 1
@pytest.mark.asyncio
async def test_chain_stops_on_interrupt(self, pipeline_app, fake_platform_adapter):
"""
Chain should stop when a stage returns INTERRUPT.
PreProcessor returns CONTINUE, Processor returns INTERRUPT (prevent_default).
"""
from langbot.pkg.pipeline import entities
from langbot.pkg.pipeline.preproc import preproc
from langbot.pkg.pipeline.process import process
adapter, platform = fake_platform_adapter
# Create query
query = text_query("hello")
query.adapter = adapter
query.pipeline_config = create_minimal_pipeline_config()
# Mock plugin_connector - PreProcessor continues, Processor interrupts
mock_event_ctx_preproc = Mock()
mock_event_ctx_preproc.event = Mock()
mock_event_ctx_preproc.event.default_prompt = []
mock_event_ctx_preproc.event.prompt = []
mock_event_ctx_processor = Mock()
mock_event_ctx_processor.is_prevented_default = Mock(return_value=True)
mock_event_ctx_processor.event = Mock()
mock_event_ctx_processor.event.reply_message_chain = None
pipeline_app.plugin_connector.emit_event = AsyncMock()
pipeline_app.plugin_connector.emit_event.side_effect = [
mock_event_ctx_preproc, # PreProcessor PromptPreProcessing
mock_event_ctx_processor, # Processor NormalMessageReceived
]
# Create stages
preproc_stage = preproc.PreProcessor(pipeline_app)
processor_stage = process.Processor(pipeline_app)
await processor_stage.initialize(query.pipeline_config)
# Run PreProcessor
result1 = await preproc_stage.process(query, 'PreProcessor')
assert result1.result_type == entities.ResultType.CONTINUE
query = result1.new_query
# Run Processor - should INTERRUPT
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
assert len(results) == 1
assert results[0].result_type == entities.ResultType.INTERRUPT
# Chain stops here - no resp_messages
assert len(query.resp_messages) == 0

View File

@@ -1,6 +0,0 @@
"""
Smoke tests package.
Smoke tests verify basic functionality works without testing edge cases.
Run with: uv run pytest tests/smoke/ -q
"""

View File

@@ -1,351 +0,0 @@
"""
Minimal fake flow smoke tests for LangBot.
These tests verify basic component interactions using fake providers and platforms.
Not a full pipeline integration test - tests individual factory components.
For full pipeline tests, see tests/integration/ (planned).
"""
from __future__ import annotations
import pytest
from tests.factories import (
FakeApp,
FakeProvider,
FakePlatform,
text_query,
fake_provider_pong,
fake_model,
mock_platform_adapter,
)
class TestFakeMessageFlow:
"""Smoke tests for fake message flow through pipeline."""
@pytest.mark.asyncio
async def test_fake_app_creation(self):
"""Test FakeApp can be created with all dependencies."""
app = FakeApp()
assert app.logger is not None
assert app.sess_mgr is not None
assert app.model_mgr is not None
assert app.tool_mgr is not None
assert app.persistence_mgr is not None
assert app.query_pool is not None
assert app.instance_config is not None
# Verify default config
assert app.instance_config.data["command"]["prefix"] == ["/", "!"]
assert app.instance_config.data["command"]["enable"] is True
@pytest.mark.asyncio
async def test_fake_provider_returns_text(self):
"""Test FakeProvider returns configured response."""
provider = FakeProvider(default_response="test response")
# Create mock model with provider
model = fake_model(provider=provider)
# Create a simple query
query = text_query("hello")
# Simulate invoke
result = await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
assert result is not None
assert result.role == "assistant"
assert result.content == "test response"
@pytest.mark.asyncio
async def test_fake_provider_pong(self):
"""Test FakeProvider returns LANGBOT_FAKE_PONG marker."""
provider = fake_provider_pong()
model = fake_model(provider=provider)
query = text_query("ping")
result = await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
assert result.content == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_fake_provider_streaming(self):
"""Test FakeProvider streaming response."""
provider = FakeProvider().returns_streaming(["Hello", " World"])
model = fake_model(provider=provider)
query = text_query("hello")
chunks = []
# invoke_llm_stream returns an async generator, don't await it
async for chunk in provider.invoke_llm_stream(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
):
chunks.append(chunk)
assert len(chunks) == 2
assert chunks[0].content == "Hello"
assert chunks[1].content == " World"
assert chunks[1].is_final is True
@pytest.mark.asyncio
async def test_fake_provider_timeout(self):
"""Test FakeProvider simulates timeout error."""
provider = FakeProvider().timeout()
model = fake_model(provider=provider)
query = text_query("hello")
with pytest.raises(TimeoutError, match="Provider timeout"):
await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
@pytest.mark.asyncio
async def test_fake_provider_rate_limit(self):
"""Test FakeProvider simulates rate limit error."""
provider = FakeProvider().rate_limit()
model = fake_model(provider=provider)
query = text_query("hello")
with pytest.raises(Exception, match="Rate limit exceeded"):
await provider.invoke_llm(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
)
@pytest.mark.asyncio
async def test_fake_provider_captures_requests(self):
"""Test FakeProvider captures request arguments."""
provider = FakeProvider()
model = fake_model(name="gpt-4", provider=provider)
query = text_query("hello")
await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "hello"}],
funcs=[{"name": "test_func"}],
extra_args={"temperature": 0.7},
)
captured = provider.get_captured_requests()
assert len(captured) == 1
assert captured[0]["model"] == "gpt-4"
assert captured[0]["messages"] == [{"role": "user", "content": "hello"}]
assert captured[0]["funcs"] == [{"name": "test_func"}]
assert captured[0]["extra_args"] == {"temperature": 0.7}
@pytest.mark.asyncio
async def test_fake_platform_capture_outbound(self):
"""Test FakePlatform captures outbound messages."""
platform = FakePlatform(bot_account_id="test-bot")
query = text_query("hello")
# Simulate sending reply
from tests.factories.message import text_chain
reply_chain = text_chain("response text")
event = query.message_event
await platform.reply_message(event, reply_chain, quote_origin=False)
# Verify captured
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert outbound[0]["message"] == reply_chain
@pytest.mark.asyncio
async def test_fake_platform_friend_message(self):
"""Test FakePlatform creates friend message events."""
platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_friend_message(
text="hello bot",
sender_id=12345,
nickname="TestUser",
)
assert event.type == "FriendMessage"
assert event.sender.id == 12345
assert event.sender.nickname == "TestUser"
assert str(event.message_chain) == "hello bot"
@pytest.mark.asyncio
async def test_fake_platform_group_message_with_mention(self):
"""Test FakePlatform creates group message with @mention."""
platform = FakePlatform(bot_account_id="test-bot")
event = platform.create_group_message(
text="hello everyone",
sender_id=12345,
group_id=99999,
mention_bot=True,
)
assert event.type == "GroupMessage"
assert event.sender.id == 12345
assert event.group.id == 99999
# Check message chain has @mention
chain = event.message_chain
assert len(chain) >= 2 # At + Plain
@pytest.mark.asyncio
async def test_query_factories_basic(self):
"""Test basic query factory functions."""
# Text query
q1 = text_query("hello world")
assert q1.launcher_type.value == "person"
assert str(q1.message_chain) == "hello world"
# Group query
from tests.factories import group_text_query
q2 = group_text_query("hello group", group_id=88888)
assert q2.launcher_type.value == "group"
assert q2.launcher_id == 88888
# Command query
from tests.factories import command_query
q3 = command_query("help", prefix="/")
assert str(q3.message_chain) == "/help"
# Mention query
from tests.factories import mention_query
q4 = mention_query("hi", target="test-bot", group_id=77777)
assert q4.launcher_type.value == "group"
@pytest.mark.asyncio
async def test_fake_platform_send_failure(self):
"""Test FakePlatform simulates send failure."""
platform = FakePlatform().send_failure()
query = text_query("hello")
from tests.factories.message import text_chain
with pytest.raises(Exception, match="Platform send failure"):
await platform.reply_message(
query.message_event,
text_chain("response"),
)
@pytest.mark.asyncio
async def test_mock_platform_adapter(self):
"""Test mock_platform_adapter helper."""
platform = FakePlatform(bot_account_id="bot-123")
adapter = mock_platform_adapter(platform)
assert adapter.bot_account_id == "bot-123"
assert adapter._fake_platform is platform
# Test reply_message is wired
from tests.factories.message import text_chain
query = text_query("test")
await adapter.reply_message(query.message_event, text_chain("response"))
# Verify platform captured it
assert len(platform.get_outbound_messages()) == 1
class TestMessageFlowIntegration:
"""Minimal fake flow integration tests.
These tests verify component interactions but do NOT run full LangBot pipeline.
For real pipeline tests, integration tests are needed (planned).
"""
@pytest.mark.asyncio
async def test_minimal_message_flow(self):
"""Minimal fake flow test: fake query -> fake provider -> fake platform.
This test verifies:
1. Fake text query is created
2. Fake provider returns LANGBOT_FAKE_PONG
3. Fake platform captures outbound response
4. No unexpected exception
Note: This does NOT run actual LangBot pipeline stages.
"""
# Setup
platform = FakePlatform(bot_account_id="test-bot")
provider = fake_provider_pong()
model = fake_model(provider=provider)
# Create inbound message
query = text_query("ping")
# Simulate provider processing
response = await provider.invoke_llm(
query=query,
model=model,
messages=[{"role": "user", "content": "ping"}],
funcs=[],
extra_args={},
)
# Verify provider returned pong
assert response.content == FakeProvider.PONG_RESPONSE
# Simulate platform sending response
from tests.factories.message import text_chain
reply_chain = text_chain(response.content)
await platform.reply_message(query.message_event, reply_chain)
# Verify platform captured outbound
outbound = platform.get_outbound_messages()
assert len(outbound) == 1
assert outbound[0]["type"] == "reply"
assert str(outbound[0]["message"]) == FakeProvider.PONG_RESPONSE
@pytest.mark.asyncio
async def test_streaming_message_flow(self):
"""Smoke test: streaming message flow."""
platform = FakePlatform().supports_streaming()
provider = FakeProvider().returns_streaming(["Hello", " there"])
model = fake_model(provider=provider)
query = text_query("hi")
chunks = []
async for chunk in provider.invoke_llm_stream(
query=query,
model=model,
messages=[],
funcs=[],
extra_args={},
):
chunks.append(chunk)
# Verify streaming worked
assert len(chunks) == 2
full_content = "".join(c.content for c in chunks)
assert full_content == "Hello there"
# Verify platform supports streaming
assert await platform.is_stream_output_supported() is True

View File

@@ -1,66 +0,0 @@
"""
PoC test for CWE-94: Authenticated RCE via exec() on user-supplied Python code.
The /api/v1/system/debug/exec endpoint passes raw HTTP body to exec(),
allowing arbitrary code execution when debug_mode is True.
This test verifies that:
1. The exec() endpoint is removed from the codebase entirely.
2. No route matches /api/v1/system/debug/exec.
"""
import ast
import pathlib
# Resolve project root (one level up from tests/)
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
VULN_FILE = (
_PROJECT_ROOT
/ "src"
/ "langbot"
/ "pkg"
/ "api"
/ "http"
/ "controller"
/ "groups"
/ "system.py"
)
def test_no_exec_call_in_system_controller():
"""Verify there is no exec() call in system.py that takes user input."""
with open(VULN_FILE, "r") as f:
source = f.read()
tree = ast.parse(source)
exec_calls = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
func = node.func
# Match bare exec() call
if isinstance(func, ast.Name) and func.id == "exec":
exec_calls.append(node.lineno)
assert len(exec_calls) == 0, (
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
"User-supplied code must never be passed to exec()."
)
def test_no_debug_exec_route():
"""Verify the /debug/exec route is not registered."""
with open(VULN_FILE, "r") as f:
source = f.read()
assert "debug/exec" not in source, (
"The /debug/exec route still exists in system.py. "
"This endpoint allows arbitrary code execution and must be removed."
)
if __name__ == "__main__":
test_no_exec_call_in_system_controller()
test_no_debug_exec_route()
print("All tests passed!")

View File

@@ -1,179 +0,0 @@
# 单元测试覆盖率排除说明
## 排除范围
以下外部适配器模块不纳入测试覆盖目标,因为它们需要实际外部环境才能测试:
### 1. 消息平台适配器 (`platform/sources/`)
- **路径**: `src/langbot/pkg/platform/sources/`
- **模块**: aiocqhttp, dingtalk, discord, feishu, gestep, kook, lark, slack, telegram, wecom, wechatpv, wechatmp, qqbot
- **排除原因**: 需要真实消息平台账号和 webhook 连接,无法纯单元测试
- **测试方式**: 需要 mock 平台 API 或集成测试环境
- **状态**: 后续可补充 mock 测试
### 2. LLM Requester (`provider/modelmgr/requesters/`)
- **路径**: `src/langbot/pkg/provider/modelmgr/requesters/`
- **模块**: deepseek, openai, anthropic, gemini, moonshot, ollama, zhipuai 等 20+ 个 requester
- **排除原因**: 需要真实 LLM API 密钥和网络请求,涉及付费 API 调用
- **测试方式**: 需要 mock HTTP 响应或使用 fake LLM server
- **状态**: 后续可补充 mock HTTP 测试
### 3. Agent Runner (`provider/runners/`)
- **路径**: `src/langbot/pkg/provider/runners/`
- **模块**: cozeapi, difysvapi, n8nsvapi, langflowapi, dashscopeapi, localagent, tboxapi
- **排除原因**: 需要真实 Agent 平台Coze、Dify、n8n 等)的 API 连接
- **测试方式**: 需要 mock Agent 平台响应
- **状态**: 后续可补充 mock 测试
### 4. 向量数据库 (`vector/vdbs/`)
- **路径**: `src/langbot/pkg/vector/vdbs/`
- **模块**: chroma, milvus, pgvector, qdrant, seekdb
- **排除原因**: 需要真实向量数据库实例运行
- **测试方式**: 需要 Docker 启动测试数据库或 mock
- **状态**: 后续可补充 mock 测试
---
## 覆盖率计算(排除外部适配器)
### 统计方法
```bash
# 排除外部适配器后计算覆盖率
pytest tests/unit_tests/ --cov=langbot.pkg \
--cov-fail-under=0 \
-o "cov_exclude_patterns=platform/sources/*,provider/modelmgr/requesters/*,provider/runners/*,vector/vdbs/*"
```
### 当前覆盖率(排除后)
| 模块 | 覆盖率 | 状态 |
|------|--------|------|
| `command` | **99%** | ✅ 完成 |
| `entity` | **99%** | ✅ 完成 |
| `vector` | **76%** | ✅ 完成 |
| `survey` | **84%** | ✅ 完成 |
| `pipeline` | **72%** | ✅ 核心流程 |
| `rag` | **66%** | ✅ 完成 |
| `telemetry` | **87%** | ✅ 完成 |
| `storage` | **80%** | ✅ 完成 |
| `provider` | **83%** | ✅ 完成 |
| `discover` | **61%** | ✅ 完成 |
| `config` | **70%** | ✅ 完成 |
| `utils` | **48%** | 🔄 部分完成 |
| `api` | **34%** | 🔄 需补充 controller |
| `platform` | **35%** | 🔄 需补充 adapter base |
| `plugin` | **27%** | 🔄 需补充 handler |
| `core` | **28%** | 🔄 需补充 app 启动 |
| `persistence` | **24%** | 🔄 需补充 mgr |
---
## 后续计划
### 可补充的 Mock 测试(优先级排序)
1. **`provider/modelmgr/requesters/`** (优先级:中)
- 使用 `httpx` mock 测试 API 响应解析
- 测试重试逻辑、错误处理
2. **`provider/runners/`** (优先级:中)
- Mock Agent 平台响应
- 测试 session 管理、错误处理
3. **`platform/sources/`** (优先级:低)
- Mock 平台 webhook 事件
- 测试消息解析、事件处理
4. **`vector/vdbs/`** (优先级:低)
- Mock 向量数据库操作
- 测试 CRUD、查询逻辑
---
## 测试文件结构
```
tests/unit_tests/
├── api/
│ └── service/
│ ├── test_knowledge_service.py # 22 tests ✅
│ └── ...
├── core/
│ ├── test_taskmgr.py # 21 tests ✅
│ ├── test_load_config.py # 21 tests ✅ (含env override)
│ └── ...
├── plugin/
│ ├── test_connector_static.py # 8 tests ✅
│ ├── test_connector_pure.py # 7 tests ✅
│ ├── test_connector_methods.py # 24 tests ✅
│ ├── test_extract_deps.py # 7 tests ✅
│ ├── test_handler_actions.py # 15 tests ✅ (新增)
│ └── ...
├── provider/
│ ├── test_session_manager.py # 11 tests ✅ (新增)
│ ├── test_tool_manager.py # 14 tests ✅ (新增)
│ └── ...
├── rag/
│ ├── test_i18n_conversion.py # 8 tests ✅
│ ├── test_kbmgr.py # 39 tests ✅
│ ├── test_file_storage.py # 21 tests ✅ (新增)
│ └── ...
├── storage/
│ ├── test_s3storage.py # 16 tests ✅ (新增)
│ ├── test_localstorage_path_traversal.py # 11 tests ✅
│ └── ...
├── survey/
│ └── test_survey_manager.py # 22 tests ✅
├── telemetry/
│ └── test_telemetry.py # 25 tests ✅ (重写)
├── vector/
│ ├── test_filter_utils.py # 21 tests ✅
│ ├── test_vdb_filter_conversion.py # 30 tests ✅ (新增)
│ └── ...
├── utils/
│ ├── test_platform.py # 7 tests ✅
│ ├── test_funcschema.py # 9 tests ✅
│ └── ...
├── pipeline/
│ ├── test_ratelimit.py # 12 tests ✅ (新增真实算法)
│ ├── test_msgtrun.py # 9 tests ✅ (强化断言)
│ └── ...
└── persistence/
├── test_serialize_model.py # 6 tests ✅
├── test_database_decorator.py # 7 tests ✅
└── ...
```
---
## 总结
- **总测试数**: 1193 passed
- **总体覆盖率**: 30%
- **核心模块覆盖率**: **51.2%** (6549/12825 语句) - 排除外部适配器
- **外部适配器覆盖率**: 5.6% (535/9483 语句) - 不纳入目标
### 核心模块覆盖率详情
| 模块 | 覆盖率 | 语句数 | 说明 |
|------|--------|--------|------|
| `command` | **99%** | 93 | ✅ 完成 |
| `entity` | **99%** | 335 | ✅ 完成 |
| `vector` | **76%** | 139 | ✅ 完成 (新增filter转换测试) |
| `survey` | **84%** | 95 | ✅ 完成 |
| `pipeline` | **72%** | 1761 | ✅ 核心流程 (新增算法测试) |
| `rag` | **66%** | 347 | ✅ 完成 (新增ZIP处理测试) |
| `telemetry` | **87%** | 70 | ✅ 完成 (重写假测试) |
| `storage` | **80%** | 170 | ✅ 完成 (新增S3测试) |
| `provider` | **83%** | 854 | ✅ 完成 (新增Session/Tool测试) |
| `discover` | **61%** | 188 | ✅ 完成 |
| `config` | **70%** | 198 | ✅ 完成 |
| `utils` | **48%** | 478 | 🔄 部分完成 |
| `api` | **34%** | 4061 | 🔄 需补充 controller |
| `platform` | **35%** | 433 | 🔄 需补充 adapter base |
| `plugin` | **27%** | 815 | 🔄 需补充 handler (新增action测试) |
| `core` | **28%** | 1289 | 🔄 需补充 app 启动 |
| `persistence` | **24%** | 1099 | 🔄 需补充 mgr |
外部适配器测试需要 mock 环境或集成测试,不属于纯单元测试范畴。

View File

@@ -1 +0,0 @@
"""Unit tests for LangBot API HTTP service layer."""

View File

@@ -1,62 +0,0 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from sqlalchemy.sql.dml import Update
from langbot.pkg.api.http.service.bot import BotService
class _FakeResult:
def __init__(self, value):
self.value = value
def first(self):
return self.value
class _PersistenceManager:
def __init__(self):
self.update_values = None
async def execute_async(self, statement):
if isinstance(statement, Update):
self.update_values = {
key: value for key, value in statement.compile().params.items() if not key.startswith('uuid_')
}
return None
return _FakeResult(SimpleNamespace(name='Updated Pipeline'))
async def test_update_bot_copies_input_before_filtering_and_setting_pipeline_name():
persistence_mgr = _PersistenceManager()
runtime_bot = SimpleNamespace(enable=False)
platform_mgr = SimpleNamespace(
remove_bot=AsyncMock(),
load_bot=AsyncMock(return_value=runtime_bot),
)
ap = SimpleNamespace(
persistence_mgr=persistence_mgr,
platform_mgr=platform_mgr,
sess_mgr=SimpleNamespace(session_list=[]),
)
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'bot-1'})
payload = {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
await service.update_bot('bot-1', payload)
assert payload == {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
assert persistence_mgr.update_values == {
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
'use_pipeline_name': 'Updated Pipeline',
}

View File

@@ -1,16 +0,0 @@
"""Unit tests for API HTTP service layer.
Tests real service business logic with mocked dependencies:
- persistence_mgr (database operations)
- model_mgr (runtime model management)
- platform_mgr (platform management)
- plugin_connector (plugin runtime)
- adjacent services (cross-service calls)
Does NOT:
- Start real Quart server
- Access real database
- Call real provider/platform/network
Uses tests.factories.FakeApp as base mock application.
"""

View File

@@ -1,429 +0,0 @@
"""
Unit tests for ApiKeyService.
Tests API key CRUD operations with mocked persistence layer.
Source: src/langbot/pkg/api/http/service/apikey.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch
from types import SimpleNamespace
from langbot.pkg.api.http.service.apikey import ApiKeyService
from langbot.pkg.entity.persistence.apikey import ApiKey
pytestmark = pytest.mark.asyncio
class TestApiKeyServiceGetApiKeys:
"""Tests for get_api_keys method."""
async def test_get_api_keys_empty_list(self):
"""Returns empty list when no API keys exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.all = Mock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
if entity
else {}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_keys()
# Verify
assert result == []
ap.persistence_mgr.execute_async.assert_called_once()
async def test_get_api_keys_returns_serialized_list(self):
"""Returns serialized list of API keys."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Create mock API key entities
key1 = Mock(spec=ApiKey)
key1.id = 1
key1.name = 'Test Key 1'
key1.key = 'lbk_test_key_1'
key1.description = 'First test key'
key2 = Mock(spec=ApiKey)
key2.id = 2
key2.name = 'Test Key 2'
key2.key = 'lbk_test_key_2'
key2.description = 'Second test key'
mock_result = Mock()
mock_result.all = Mock(return_value=[key1, key2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_keys()
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Test Key 1'
assert result[1]['name'] == 'Test Key 2'
class TestApiKeyServiceCreateApiKey:
"""Tests for create_api_key method."""
async def test_create_api_key_generates_key_with_prefix(self):
"""Creates API key with 'lbk_' prefix."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_key = Mock(spec=ApiKey)
created_key.id = 1
created_key.name = 'New Key'
created_key.key = 'lbk_fixed-token'
created_key.description = 'Test description'
select_result = Mock()
select_result.first = Mock(return_value=created_key)
insert_params = []
async def mock_execute(query):
params = query.compile().params
if {'name', 'key', 'description'}.issubset(params):
insert_params.append(params)
return Mock()
return select_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': 1,
'name': entity.name,
'key': entity.key,
'description': entity.description,
}
)
service = ApiKeyService(ap)
with patch('langbot.pkg.api.http.service.apikey.secrets.token_urlsafe', return_value='fixed-token'):
result = await service.create_api_key('New Key', 'Test description')
assert insert_params == [
{'name': 'New Key', 'key': 'lbk_fixed-token', 'description': 'Test description'}
]
assert result['key'].startswith('lbk_')
assert result['key'] == 'lbk_fixed-token'
assert result['name'] == 'New Key'
assert result['description'] == 'Test description'
async def test_create_api_key_without_description(self):
"""Creates API key with empty description when not provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_key = Mock(spec=ApiKey)
created_key.id = 1
created_key.name = 'No Desc Key'
created_key.key = 'lbk_no_desc_key'
created_key.description = ''
select_result = Mock()
select_result.first = Mock(return_value=created_key)
insert_result = Mock()
async def mock_execute(query):
if hasattr(query, 'values'):
return insert_result
return select_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'No Desc Key',
'key': 'lbk_no_desc_key',
'description': '',
}
)
service = ApiKeyService(ap)
# Execute
result = await service.create_api_key('No Desc Key')
# Verify
assert result['description'] == ''
class TestApiKeyServiceGetApiKey:
"""Tests for get_api_key method."""
async def test_get_api_key_by_id_found(self):
"""Returns API key when found by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
key = Mock(spec=ApiKey)
key.id = 1
key.name = 'Found Key'
key.key = 'lbk_found_key'
key.description = 'Found'
mock_result = Mock()
mock_result.first = Mock(return_value=key)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'Found Key',
'key': 'lbk_found_key',
'description': 'Found',
}
)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(1)
# Verify
assert result is not None
assert result['id'] == 1
assert result['name'] == 'Found Key'
async def test_get_api_key_by_id_not_found(self):
"""Returns None when API key not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(999)
# Verify
assert result is None
async def test_get_api_key_by_id_zero(self):
"""Handles ID=0 (edge case) correctly."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.get_api_key(0)
# Verify - should return None (no key with ID 0)
assert result is None
class TestApiKeyServiceVerifyApiKey:
"""Tests for verify_api_key method."""
async def test_verify_api_key_valid(self):
"""Returns True for valid API key."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
key = Mock(spec=ApiKey)
mock_result = Mock()
mock_result.first = Mock(return_value=key)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('lbk_valid_key')
# Verify
assert result is True
async def test_verify_api_key_invalid(self):
"""Returns False for invalid API key."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('lbk_invalid_key')
# Verify
assert result is False
async def test_verify_api_key_empty_string(self):
"""Returns False for empty key string."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('')
# Verify
assert result is False
async def test_verify_api_key_unknown_key(self):
"""Returns False when the key is not present in persistence."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ApiKeyService(ap)
# Execute
result = await service.verify_api_key('unknown_key')
# Verify
assert result is False
class TestApiKeyServiceDeleteApiKey:
"""Tests for delete_api_key method."""
async def test_delete_api_key_by_id(self):
"""Deletes API key by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.delete_api_key(1)
# Verify - execute_async was called (delete operation)
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_api_key_nonexistent_id(self):
"""Delete operation completes even for nonexistent ID (no error raised)."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute - should not raise error
await service.delete_api_key(999)
# Verify - execute_async was called regardless
ap.persistence_mgr.execute_async.assert_called_once()
class TestApiKeyServiceUpdateApiKey:
"""Tests for update_api_key method."""
async def test_update_api_key_name_only(self):
"""Updates only the name field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, name='Updated Name')
# Verify - execute_async was called with update
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_description_only(self):
"""Updates only the description field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, description='Updated description')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_both_fields(self):
"""Updates both name and description."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1, name='New Name', description='New description')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_api_key_no_fields(self):
"""Does nothing when no fields provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = ApiKeyService(ap)
# Execute
await service.update_api_key(1)
# Verify - no execute call since no update_data
ap.persistence_mgr.execute_async.assert_not_called()

View File

@@ -1,662 +0,0 @@
"""
Unit tests for BotService.
Tests bot CRUD operations with mocked persistence and runtime managers.
Source: src/langbot/pkg/api/http/service/bot.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch
from types import SimpleNamespace
import uuid
from langbot.pkg.api.http.service.bot import BotService
from langbot.pkg.entity.persistence.bot import Bot
pytestmark = pytest.mark.asyncio
def _create_mock_bot(
bot_uuid: str = None,
name: str = 'Test Bot',
description: str = 'Test Description',
adapter: str = 'telegram',
adapter_config: dict = None,
enable: bool = True,
use_pipeline_uuid: str = None,
use_pipeline_name: str = None,
) -> Mock:
"""Helper to create mock Bot entity."""
bot = Mock(spec=Bot)
bot.uuid = bot_uuid or str(uuid.uuid4())
bot.name = name
bot.description = description
bot.adapter = adapter
bot.adapter_config = adapter_config or {'token': 'test_token'}
bot.enable = enable
bot.use_pipeline_uuid = use_pipeline_uuid
bot.use_pipeline_name = use_pipeline_name
bot.pipeline_routing_rules = []
return bot
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestBotServiceGetBots:
"""Tests for get_bots method."""
async def test_get_bots_empty_list(self):
"""Returns empty list when no bots exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots()
# Verify
assert result == []
async def test_get_bots_returns_list_with_secrets(self):
"""Returns bot list including adapter_config by default."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
bot2 = _create_mock_bot(bot_uuid='uuid-2', name='Bot 2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots(include_secret=True)
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Bot 1'
assert result[0]['adapter_config'] is not None
async def test_get_bots_masks_secrets(self):
"""Returns bot list without adapter_config when include_secret=False."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot1 = _create_mock_bot(bot_uuid='uuid-1', name='Bot 1')
mock_result = _create_mock_result([bot1])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity, masked_columns=None: {
'uuid': entity.uuid,
'name': entity.name,
'adapter': entity.adapter,
'adapter_config': entity.adapter_config if 'adapter_config' not in (masked_columns or []) else None,
}
)
service = BotService(ap)
# Execute
result = await service.get_bots(include_secret=False)
# Verify - adapter_config should be masked
assert result[0]['adapter_config'] is None
class TestBotServiceGetBot:
"""Tests for get_bot method."""
async def test_get_bot_by_uuid_found(self):
"""Returns bot when found by UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
bot = _create_mock_bot(bot_uuid='test-uuid', name='Found Bot')
mock_result = _create_mock_result(first_item=bot)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'name': 'Found Bot',
'adapter': 'telegram',
}
)
service = BotService(ap)
# Execute
result = await service.get_bot('test-uuid')
# Verify
assert result is not None
assert result['uuid'] == 'test-uuid'
assert result['name'] == 'Found Bot'
async def test_get_bot_by_uuid_not_found(self):
"""Returns None when bot not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = BotService(ap)
# Execute
result = await service.get_bot('nonexistent-uuid')
# Verify
assert result is None
class TestBotServiceGetRuntimeBotInfo:
"""Tests for get_runtime_bot_info method."""
async def test_get_runtime_bot_info_bot_not_found_raises(self):
"""Raises Exception when bot not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = BotService(ap)
# Mock get_bot to return None
service.get_bot = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.get_runtime_bot_info('nonexistent-uuid')
async def test_get_runtime_bot_info_returns_webhook_for_wecom(self):
"""Returns webhook URL for wecom adapter."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'api': {
'webhook_prefix': 'http://127.0.0.1:5300',
'extra_webhook_prefix': 'http://extra.example.com',
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
bot_data = {
'uuid': 'wecom-uuid',
'name': 'WeCom Bot',
'adapter': 'wecom',
'adapter_config': {'token': 'test'},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('wecom-uuid')
# Verify
assert result['adapter_runtime_values']['webhook_url'] == '/bots/wecom-uuid'
assert result['adapter_runtime_values']['webhook_full_url'] == 'http://127.0.0.1:5300/bots/wecom-uuid'
async def test_get_runtime_bot_info_no_webhook_for_telegram(self):
"""Returns no webhook URL for non-webhook adapters like telegram."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'api': {}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
bot_data = {
'uuid': 'telegram-uuid',
'name': 'Telegram Bot',
'adapter': 'telegram',
'adapter_config': {'token': 'test'},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('telegram-uuid')
# Verify - no webhook for telegram
assert result['adapter_runtime_values']['webhook_url'] is None
assert result['adapter_runtime_values']['webhook_full_url'] is None
async def test_get_runtime_bot_info_with_runtime_bot(self):
"""Returns bot_account_id when runtime bot exists."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'api': {}}
ap.platform_mgr = SimpleNamespace()
# Mock runtime bot with adapter
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.bot_account_id = 'runtime-account-123'
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
bot_data = {
'uuid': 'runtime-uuid',
'name': 'Runtime Bot',
'adapter': 'telegram',
'adapter_config': {},
}
service = BotService(ap)
service.get_bot = AsyncMock(return_value=bot_data)
# Execute
result = await service.get_runtime_bot_info('runtime-uuid')
# Verify
assert result['adapter_runtime_values']['bot_account_id'] == 'runtime-account-123'
class TestBotServiceCreateBot:
"""Tests for create_bot method."""
async def test_create_bot_max_limit_reached_raises(self):
"""Raises ValueError when max_bots limit reached."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': 2
}
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock get_bots to return 2 bots already
bot1 = _create_mock_bot(bot_uuid='uuid-1')
bot2 = _create_mock_bot(bot_uuid='uuid-2')
mock_result = _create_mock_result([bot1, bot2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Bot 1'}
)
service = BotService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of bots'):
await service.create_bot({'name': 'New Bot'})
async def test_create_bot_no_limit(self):
"""Creates bot without limit check when max_bots=-1."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_bots': -1 # No limit
}
}
}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock pipeline query
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=None)
# Mock bot query after insert
bot_result = Mock()
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count <= 2:
return pipeline_result # First call: check pipeline
elif call_count == 3:
return Mock() # Insert
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Bot'}
)
service = BotService(ap)
# Execute
bot_uuid = await service.create_bot({'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}})
# Verify
assert bot_uuid is not None
assert len(bot_uuid) == 36 # UUID format
async def test_create_bot_sets_default_pipeline(self):
"""Sets default pipeline when one exists."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_bots': -1}}}
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.load_bot = AsyncMock()
# Mock default pipeline
mock_pipeline = SimpleNamespace()
mock_pipeline.uuid = 'default-pipeline-uuid'
mock_pipeline.name = 'Default Pipeline'
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=mock_pipeline)
# Mock bot after insert
bot_result = Mock()
bot_result.first = Mock(return_value=_create_mock_bot())
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return pipeline_result # Check default pipeline
elif call_count == 2:
return Mock() # Insert
return bot_result # Get bot
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'new-uuid',
'name': 'New Bot',
'use_pipeline_uuid': 'default-pipeline-uuid',
'use_pipeline_name': 'Default Pipeline',
}
)
service = BotService(ap)
# Execute
bot_data = {'name': 'New Bot', 'adapter': 'telegram', 'adapter_config': {}}
bot_uuid = await service.create_bot(bot_data)
# Verify - pipeline uuid and name were set
assert 'use_pipeline_uuid' in bot_data
assert 'use_pipeline_name' in bot_data
assert bot_uuid is not None # Verify UUID was returned
class TestBotServiceUpdateBot:
"""Tests for update_bot method."""
async def test_update_bot_removes_uuid_from_data(self):
"""Does not persist caller-provided uuid in update payload."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
# Mock pipeline query - not updating pipeline
ap.persistence_mgr.execute_async = AsyncMock()
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
# Create mock runtime bot
runtime_bot = SimpleNamespace()
runtime_bot.enable = False
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
# Execute
update_data = {'uuid': 'should-be-removed', 'name': 'Updated Name'}
await service.update_bot('test-uuid', update_data)
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
assert update_params['name'] == 'Updated Name'
assert 'should-be-removed' not in update_params.values()
async def test_update_bot_pipeline_not_found_raises(self):
"""Raises Exception when updating with nonexistent pipeline UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Mock pipeline query returns None
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=pipeline_result)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Pipeline not found'):
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'nonexistent-pipeline'})
async def test_update_bot_sets_pipeline_name(self):
"""Sets use_pipeline_name when updating use_pipeline_uuid."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
# Mock pipeline query
mock_pipeline = SimpleNamespace()
mock_pipeline.name = 'Updated Pipeline'
pipeline_result = Mock()
pipeline_result.first = Mock(return_value=mock_pipeline)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return pipeline_result
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'test-uuid'})
runtime_bot = SimpleNamespace()
runtime_bot.enable = False
ap.platform_mgr.load_bot = AsyncMock(return_value=runtime_bot)
# Execute
await service.update_bot('test-uuid', {'use_pipeline_uuid': 'pipeline-uuid'})
update_params = ap.persistence_mgr.execute_async.await_args_list[1].args[0].compile().params
assert update_params['use_pipeline_uuid'] == 'pipeline-uuid'
assert update_params['use_pipeline_name'] == 'Updated Pipeline'
class TestBotServiceDeleteBot:
"""Tests for delete_bot method."""
async def test_delete_bot_calls_remove_and_delete(self):
"""Calls both platform_mgr.remove_bot and persistence delete."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
service = BotService(ap)
# Execute
await service.delete_bot('test-uuid')
# Verify
ap.platform_mgr.remove_bot.assert_called_once_with('test-uuid')
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_bot_nonexistent_uuid(self):
"""Delete operation completes even for nonexistent UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.remove_bot = AsyncMock()
service = BotService(ap)
# Execute - should not raise
await service.delete_bot('nonexistent-uuid')
# Verify - both called regardless
ap.platform_mgr.remove_bot.assert_called_once()
class TestBotServiceListEventLogs:
"""Tests for list_event_logs method."""
async def test_list_event_logs_bot_not_found_raises(self):
"""Raises Exception when runtime bot not found."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.list_event_logs('nonexistent-uuid', 0, 10)
async def test_list_event_logs_returns_logs(self):
"""Returns logs from runtime bot logger."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
# Mock runtime bot with logger
runtime_bot = SimpleNamespace()
runtime_bot.logger = SimpleNamespace()
runtime_bot.logger.get_logs = AsyncMock(return_value=(
[SimpleNamespace(to_json=Mock(return_value={'msg': 'log1'}))],
5
))
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute
logs, total = await service.list_event_logs('bot-uuid', 0, 10)
# Verify
assert len(logs) == 1
assert logs[0] == {'msg': 'log1'}
assert total == 5
class TestBotServiceSendMessage:
"""Tests for send_message method."""
async def test_send_message_bot_not_found_raises(self):
"""Raises Exception when bot not found."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=None)
service = BotService(ap)
# Execute & Verify
with pytest.raises(Exception, match='Bot not found'):
await service.send_message('nonexistent-uuid', 'group', '123', {'test': 'data'})
async def test_send_message_invalid_message_chain_raises(self):
"""Raises Exception when message_chain_data is invalid."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.send_message = AsyncMock()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute & Verify - invalid format should raise
with pytest.raises(Exception, match='Invalid message_chain format'):
await service.send_message('bot-uuid', 'group', '123', {'invalid': 'format'})
async def test_send_message_valid_call(self):
"""Sends message through adapter when all valid."""
# Setup
ap = SimpleNamespace()
ap.platform_mgr = SimpleNamespace()
runtime_bot = SimpleNamespace()
runtime_bot.adapter = SimpleNamespace()
runtime_bot.adapter.send_message = AsyncMock()
ap.platform_mgr.get_bot_by_uuid = AsyncMock(return_value=runtime_bot)
service = BotService(ap)
# Execute with valid message chain format
message_chain_data = {
'messages': [
{'type': 'text', 'data': {'text': 'Hello'}}
]
}
# Patch the import location - the module imports inside the function
with patch('langbot_plugin.api.entities.builtin.platform.message.MessageChain') as MockMessageChain:
mock_chain = Mock()
MockMessageChain.model_validate = Mock(return_value=mock_chain)
await service.send_message('bot-uuid', 'group', '123', message_chain_data)
# Verify adapter.send_message was called
runtime_bot.adapter.send_message.assert_called_once_with('group', '123', mock_chain)

View File

@@ -1,397 +0,0 @@
"""Unit tests for API knowledge service.
Tests cover:
- Knowledge base CRUD operations
- Capability checking
- Knowledge engine discovery
- File operations
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
def get_knowledge_service_module():
"""Lazy import to avoid circular import issues."""
return import_module('langbot.pkg.api.http.service.knowledge')
def create_mock_app():
"""Create mock Application for testing."""
mock_app = Mock()
mock_app.logger = Mock()
mock_app.rag_mgr = AsyncMock()
mock_app.persistence_mgr = AsyncMock()
mock_app.persistence_mgr.execute_async = AsyncMock()
mock_app.persistence_mgr.serialize_model = Mock(return_value={})
mock_app.plugin_connector = AsyncMock()
mock_app.plugin_connector.is_enable_plugin = True
return mock_app
class TestKnowledgeServiceInit:
"""Tests for KnowledgeService initialization."""
def test_init_stores_app_reference(self):
"""Test that __init__ stores Application reference."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
assert service.ap is mock_app
class TestGetKnowledgeBases:
"""Tests for get_knowledge_bases method."""
@pytest.mark.asyncio
async def test_returns_all_kb_details(self):
"""Test that it returns all knowledge base details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(
return_value=[{'uuid': 'kb1', 'name': 'KB1'}]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
assert len(result) == 1
assert result[0]['uuid'] == 'kb1'
@pytest.mark.asyncio
async def test_returns_empty_list_when_no_kbs(self):
"""Test that it returns empty list when no knowledge bases."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_all_knowledge_base_details = AsyncMock(return_value=[])
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_bases()
assert result == []
class TestGetKnowledgeBase:
"""Tests for get_knowledge_base method."""
@pytest.mark.asyncio
async def test_returns_kb_details_by_uuid(self):
"""Test that it returns specific KB details."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'KB1'}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('kb1')
assert result['uuid'] == 'kb1'
@pytest.mark.asyncio
async def test_returns_none_when_not_found(self):
"""Test that it returns None when KB not found."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_knowledge_base('nonexistent')
assert result is None
class TestCreateKnowledgeBase:
"""Tests for create_knowledge_base method."""
@pytest.mark.asyncio
async def test_creates_kb_with_required_fields(self):
"""Test creating KB with required plugin ID."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_kb = Mock()
mock_kb.uuid = 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
service = knowledge_module.KnowledgeService(mock_app)
kb_data = {
'name': 'Test KB',
'knowledge_engine_plugin_id': 'author/engine',
'description': 'Test description',
}
result = await service.create_knowledge_base(kb_data)
assert result == 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base.assert_called_once()
@pytest.mark.asyncio
async def test_raises_when_missing_plugin_id(self):
"""Test that ValueError is raised when plugin ID missing."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(ValueError) as exc_info:
await service.create_knowledge_base({'name': 'Test'})
assert 'knowledge_engine_plugin_id is required' in str(exc_info.value)
@pytest.mark.asyncio
async def test_creates_with_default_name(self):
"""Test that KB is created with default name if not provided."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_kb = Mock()
mock_kb.uuid = 'new_kb_uuid'
mock_app.rag_mgr.create_knowledge_base = AsyncMock(return_value=mock_kb)
service = knowledge_module.KnowledgeService(mock_app)
await service.create_knowledge_base({
'knowledge_engine_plugin_id': 'author/engine'
})
# Check that default name 'Untitled' was used
call_args = mock_app.rag_mgr.create_knowledge_base.call_args
assert call_args.kwargs['name'] == 'Untitled'
class TestUpdateKnowledgeBase:
"""Tests for update_knowledge_base method."""
@pytest.mark.asyncio
async def test_updates_mutable_fields_only(self):
"""Test that only mutable fields are updated."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'uuid': 'kb1', 'name': 'Updated'}
)
mock_app.rag_mgr.remove_knowledge_base_from_runtime = AsyncMock()
mock_app.rag_mgr.load_knowledge_base = AsyncMock()
service = knowledge_module.KnowledgeService(mock_app)
# Pass both mutable and immutable fields
await service.update_knowledge_base('kb1', {
'name': 'New Name',
'description': 'New desc',
'uuid': 'should_be_filtered', # immutable
})
# Check that only mutable fields were passed to update
call_args = mock_app.persistence_mgr.execute_async.call_args
assert call_args is not None
@pytest.mark.asyncio
async def test_returns_early_when_no_mutable_fields(self):
"""Test that update returns early when no mutable fields provided."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
service = knowledge_module.KnowledgeService(mock_app)
# Pass only immutable fields
await service.update_knowledge_base('kb1', {'uuid': 'should_be_filtered'})
# No DB update should be called
mock_app.persistence_mgr.execute_async.assert_not_called()
class TestCheckDocCapability:
"""Tests for _check_doc_capability method."""
@pytest.mark.asyncio
async def test_passes_when_capability_supported(self):
"""Test that check passes when doc_ingestion capability exists."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'knowledge_engine': {'capabilities': ['doc_ingestion']}}
)
service = knowledge_module.KnowledgeService(mock_app)
await service._check_doc_capability('kb1', 'document upload')
# No exception raised means success
@pytest.mark.asyncio
async def test_raises_when_kb_not_found(self):
"""Test that Exception is raised when KB not found."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(return_value=None)
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(Exception) as exc_info:
await service._check_doc_capability('nonexistent', 'test operation')
assert 'Knowledge base not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_raises_when_capability_not_supported(self):
"""Test that Exception is raised when doc_ingestion not in capabilities."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.rag_mgr.get_knowledge_base_details = AsyncMock(
return_value={'knowledge_engine': {'capabilities': ['other_capability']}}
)
service = knowledge_module.KnowledgeService(mock_app)
with pytest.raises(Exception) as exc_info:
await service._check_doc_capability('kb1', 'document upload')
assert 'does not support document upload' in str(exc_info.value)
class TestListKnowledgeEngines:
"""Tests for list_knowledge_engines method."""
@pytest.mark.asyncio
async def test_returns_engines_from_plugin_connector(self):
"""Test that it returns knowledge engines from plugin connector."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
return_value=[{'id': 'engine1', 'name': 'Engine 1'}]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert len(result) == 1
assert result[0]['id'] == 'engine1'
@pytest.mark.asyncio
async def test_returns_empty_when_plugin_disabled(self):
"""Test that it returns empty list when plugin disabled."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.is_enable_plugin = False
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert result == []
@pytest.mark.asyncio
async def test_returns_empty_on_exception(self):
"""Test that it returns empty list and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_knowledge_engines = AsyncMock(
side_effect=Exception('Connection error')
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_knowledge_engines()
assert result == []
mock_app.logger.warning.assert_called_once()
class TestListParsers:
"""Tests for list_parsers method."""
@pytest.mark.asyncio
async def test_returns_all_parsers(self):
"""Test that it returns all parsers when no MIME type filter."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_parsers = AsyncMock(
return_value=[
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers()
assert len(result) == 2
@pytest.mark.asyncio
async def test_filters_by_mime_type(self):
"""Test that it filters parsers by MIME type."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.list_parsers = AsyncMock(
return_value=[
{'id': 'parser1', 'supported_mime_types': ['text/plain']},
{'id': 'parser2', 'supported_mime_types': ['application/pdf']},
]
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers(mime_type='application/pdf')
assert len(result) == 1
assert result[0]['id'] == 'parser2'
@pytest.mark.asyncio
async def test_returns_empty_when_plugin_disabled(self):
"""Test that it returns empty list when plugin disabled."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.is_enable_plugin = False
service = knowledge_module.KnowledgeService(mock_app)
result = await service.list_parsers()
assert result == []
class TestGetEngineSchemas:
"""Tests for get_engine_creation_schema and get_engine_retrieval_schema."""
@pytest.mark.asyncio
async def test_returns_creation_schema(self):
"""Test that it returns creation schema for engine."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
return_value={'properties': {'name': {'type': 'string'}}}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert 'properties' in result
@pytest.mark.asyncio
async def test_returns_retrieval_schema(self):
"""Test that it returns retrieval schema for engine."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_retrieval_schema = AsyncMock(
return_value={'properties': {'top_k': {'type': 'integer'}}}
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_retrieval_schema('author/engine')
assert 'properties' in result
@pytest.mark.asyncio
async def test_returns_empty_dict_on_exception(self):
"""Test that it returns empty dict and logs warning on exception."""
knowledge_module = get_knowledge_service_module()
mock_app = create_mock_app()
mock_app.plugin_connector.get_rag_creation_schema = AsyncMock(
side_effect=Exception('Plugin error')
)
service = knowledge_module.KnowledgeService(mock_app)
result = await service.get_engine_creation_schema('author/engine')
assert result == {}
mock_app.logger.warning.assert_called_once()

View File

@@ -1,824 +0,0 @@
"""
Unit tests for MaintenanceService.
Tests storage maintenance and diagnostics including:
- Cleanup expired files
- Storage analysis
- File counting and sizing
- Monitoring counts
- Binary storage stats
Source: src/langbot/pkg/api/http/service/maintenance.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from types import SimpleNamespace
import datetime
from pathlib import Path
from langbot.pkg.api.http.service.maintenance import MaintenanceService
pytestmark = pytest.mark.asyncio
def _create_mock_result(scalar_value=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.scalar = Mock(return_value=scalar_value)
return result
class TestMaintenanceServiceCleanupExpiredFiles:
"""Tests for cleanup_expired_files method."""
async def test_cleanup_expired_files_default_retention(self):
"""Uses default retention days when config not set."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.storage_mgr = SimpleNamespace()
# Create a proper mock object with __class__.__name__
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods - one is async, one is not
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async!
# Execute
result = await service.cleanup_expired_files()
# Verify - returns counts
assert 'uploaded_files' in result
assert 'log_files' in result
assert result['uploaded_files'] == 0
assert result['log_files'] == 0
async def test_cleanup_expired_files_custom_retention(self):
"""Uses custom retention days from config."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'storage': {
'cleanup': {
'uploaded_file_retention_days': 14,
'log_retention_days': 7,
}
}
}
ap.storage_mgr = SimpleNamespace()
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=2)
service._cleanup_expired_log_files = Mock(return_value=3) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify
assert result['uploaded_files'] == 2
assert result['log_files'] == 3
async def test_cleanup_expired_files_s3_provider(self):
"""Handles S3StorageProvider correctly."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.storage_mgr = SimpleNamespace()
# Mock S3 provider
s3_provider = MagicMock()
s3_provider.__class__.__name__ = 'S3StorageProvider'
s3_provider.delete = AsyncMock()
ap.storage_mgr.storage_provider = s3_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=1)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify
assert result['uploaded_files'] == 1
assert result['log_files'] == 0
async def test_cleanup_expired_files_invalid_retention(self):
"""Uses default for invalid retention config."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'storage': {
'cleanup': {
'uploaded_file_retention_days': 'invalid', # Invalid
'log_retention_days': 0, # Invalid (less than 1)
}
}
}
ap.storage_mgr = SimpleNamespace()
storage_provider = MagicMock()
storage_provider.__class__.__name__ = 'LocalStorageProvider'
ap.storage_mgr.storage_provider = storage_provider
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Mock the internal cleanup methods
service._cleanup_expired_uploaded_files = AsyncMock(return_value=0)
service._cleanup_expired_log_files = Mock(return_value=0) # NOT async
# Execute
result = await service.cleanup_expired_files()
# Verify - warning logged, defaults used
assert ap.logger.warning.called
assert 'uploaded_files' in result
class TestMaintenanceServiceGetStorageAnalysis:
"""Tests for get_storage_analysis method."""
async def test_get_storage_analysis_basic(self):
"""Returns basic storage analysis."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}
}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = SimpleNamespace()
ap.task_mgr.get_stats = Mock(return_value={'running': 0})
# Mock monitoring counts
count_result = _create_mock_result(scalar_value=10)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Mock file operations
service._path_size = Mock(return_value=1000)
service._file_count = Mock(return_value=5)
service._monitoring_counts = AsyncMock(return_value={'messages': 10, 'errors': 0})
service._binary_storage_stats = AsyncMock(return_value={'count': 5, 'size_bytes': 500})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify
assert 'generated_at' in result
assert 'cleanup_policy' in result
assert 'sections' in result
assert 'database' in result
assert 'cleanup_candidates' in result
async def test_get_storage_analysis_sections(self):
"""Returns all storage sections."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'database': {'use': 'postgresql'}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify - all sections present
sections = {s['key'] for s in result['sections']}
assert 'database' in sections
assert 'logs' in sections
assert 'storage' in sections
assert 'vector_store' in sections
assert 'plugins' in sections
assert 'mcp' in sections
assert 'temp' in sections
async def test_get_storage_analysis_postgresql(self):
"""Handles PostgreSQL database type."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'database': {'use': 'postgresql'}}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': None})
service._expired_uploaded_candidates = AsyncMock(return_value=[])
service._expired_log_candidates = Mock(return_value=[])
# Execute
result = await service.get_storage_analysis()
# Verify
assert result['database']['type'] == 'postgresql'
async def test_get_storage_analysis_with_cleanup_candidates(self):
"""Returns cleanup candidates in analysis."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
ap.task_mgr = None
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
service._path_size = Mock(return_value=0)
service._file_count = Mock(return_value=0)
service._monitoring_counts = AsyncMock(return_value={})
service._binary_storage_stats = AsyncMock(return_value={'count': 0, 'size_bytes': 0})
service._expired_uploaded_candidates = AsyncMock(return_value=[
{'key': 'old_file', 'size_bytes': 100}
])
service._expired_log_candidates = Mock(return_value=[
{'name': 'old_log', 'size_bytes': 50}
])
# Execute
result = await service.get_storage_analysis()
# Verify
assert len(result['cleanup_candidates']['uploaded_files']) == 1
assert len(result['cleanup_candidates']['log_files']) == 1
class TestMaintenanceServiceMonitoringCounts:
"""Tests for _monitoring_counts method."""
async def test_monitoring_counts_returns_counts(self):
"""Returns counts for all monitoring tables."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
count_result = _create_mock_result(scalar_value=42)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Execute
result = await service._monitoring_counts()
# Verify - all table keys present
assert 'messages' in result
assert 'llm_calls' in result
assert 'embedding_calls' in result
assert 'errors' in result
assert 'sessions' in result
assert 'feedback' in result
async def test_monitoring_counts_zero_results(self):
"""Returns zero counts when tables empty."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
count_result = _create_mock_result(scalar_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=count_result)
service = MaintenanceService(ap)
# Execute
result = await service._monitoring_counts()
# Verify - all zero
assert all(v == 0 for v in result.values())
class TestMaintenanceServiceBinaryStorageStats:
"""Tests for _binary_storage_stats method."""
async def test_binary_storage_stats_returns_stats(self):
"""Returns count and size for binary storage."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
# Mock count result
count_result = _create_mock_result(scalar_value=10)
# Mock size result
size_result = _create_mock_result(scalar_value=5000)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return count_result
return size_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MaintenanceService(ap)
# Execute
result = await service._binary_storage_stats()
# Verify
assert result['count'] == 10
assert result['size_bytes'] == 5000
async def test_binary_storage_stats_size_error(self):
"""Handles error when calculating size."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
count_result = _create_mock_result(scalar_value=5)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return count_result
raise Exception('Size calculation error')
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MaintenanceService(ap)
# Execute
result = await service._binary_storage_stats()
# Verify - warning logged, size_bytes None or 0
assert ap.logger.warning.called
assert result['count'] == 5
class TestMaintenanceServicePathSize:
"""Tests for _path_size method."""
def test_path_size_nonexistent_path(self):
"""Returns 0 for nonexistent path."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Execute
result = service._path_size(Path('/nonexistent/path'))
# Verify
assert result == 0
def test_path_size_single_file(self):
"""Returns size for single file."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock file
mock_stat = Mock()
mock_stat.st_size = 100
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=True):
with patch.object(Path, 'stat', return_value=mock_stat):
result = service._path_size(Path('test.txt'))
# Verify
assert result == 100
def test_path_size_directory(self):
"""Returns total size for directory."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock os.walk
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=False):
with patch('os.walk') as mock_walk:
mock_walk.return_value = [
('/test_dir', [], ['file1.txt', 'file2.txt']),
]
# Mock file stat
mock_stat = Mock()
mock_stat.st_size = 50
with patch.object(Path, 'stat', return_value=mock_stat):
result = service._path_size(Path('/test_dir'))
# Verify - 2 files * 50 bytes
assert result == 100
class TestMaintenanceServiceFileCount:
"""Tests for _file_count method."""
def test_file_count_nonexistent_path(self):
"""Returns 0 for nonexistent path."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Execute
result = service._file_count(Path('/nonexistent/path'))
# Verify
assert result == 0
def test_file_count_single_file(self):
"""Returns 1 for single file."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=True):
result = service._file_count(Path('test.txt'))
# Verify
assert result == 1
def test_file_count_directory(self):
"""Returns file count for directory."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'is_file', return_value=False):
with patch('os.walk') as mock_walk:
mock_walk.return_value = [
('/test_dir', [], ['file1.txt', 'file2.txt', 'file3.txt']),
]
result = service._file_count(Path('/test_dir'))
# Verify
assert result == 3
class TestMaintenanceServicePositiveInt:
"""Tests for _positive_int helper method."""
def test_positive_int_valid_value(self):
"""Returns valid positive integer."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(7, 5, 'test_param')
# Verify
assert result == 7
assert not ap.logger.warning.called
def test_positive_int_invalid_string(self):
"""Returns default for invalid string."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int('invalid', 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_invalid_none(self):
"""Returns default for None."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(None, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_negative_value(self):
"""Returns default for negative value."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(-1, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
def test_positive_int_zero_value(self):
"""Returns default for zero value."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
ap.logger.warning = Mock()
service = MaintenanceService(ap)
# Execute
result = service._positive_int(0, 5, 'test_param')
# Verify
assert result == 5
assert ap.logger.warning.called
class TestMaintenanceServiceIsUploadedFileKey:
"""Tests for _is_uploaded_file_key helper method."""
def test_is_uploaded_file_key_valid(self):
"""Returns True for valid upload file key."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - simple filename without path
result = service._is_uploaded_file_key('uploaded_file.txt')
# Verify
assert result is True
def test_is_uploaded_file_key_with_path(self):
"""Returns False for key with path separator."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - key with path
result = service._is_uploaded_file_key('path/to/file.txt')
# Verify
assert result is False
def test_is_uploaded_file_key_plugin_config(self):
"""Returns False for plugin config prefix."""
# Setup
ap = SimpleNamespace()
service = MaintenanceService(ap)
# Execute - plugin config file
result = service._is_uploaded_file_key('plugin_config_some_plugin.json')
# Verify
assert result is False
class TestMaintenanceServiceExpiredLogCandidates:
"""Tests for _expired_log_candidates method."""
def test_expired_log_candidates_nonexistent_dir(self):
"""Returns empty list when logs dir not exists."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=False):
result = service._expired_log_candidates(3)
# Verify
assert result == []
def test_expired_log_candidates_matches_pattern(self):
"""Matches log file pattern correctly."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock directory with log files
old_date = datetime.date.today() - datetime.timedelta(days=10)
old_log_name = f'langbot-{old_date.isoformat()}.log'
recent_log_name = f'langbot-{datetime.date.today().isoformat()}.log'
mock_entry_old = Mock(spec=Path)
mock_entry_old.is_file = Mock(return_value=True)
mock_entry_old.name = old_log_name
mock_stat = Mock()
mock_stat.st_size = 1000
mock_entry_old.stat = Mock(return_value=mock_stat)
mock_entry_recent = Mock(spec=Path)
mock_entry_recent.is_file = Mock(return_value=True)
mock_entry_recent.name = recent_log_name
mock_stat2 = Mock()
mock_stat2.st_size = 500
mock_entry_recent.stat = Mock(return_value=mock_stat2)
# Non-log file
mock_entry_other = Mock(spec=Path)
mock_entry_other.is_file = Mock(return_value=True)
mock_entry_other.name = 'other_file.txt'
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry_old, mock_entry_recent, mock_entry_other]
result = service._expired_log_candidates(3)
# Verify - only old log included
assert len(result) == 1
assert result[0]['name'] == old_log_name
def test_expired_log_candidates_includes_path(self):
"""Includes path when include_paths=True."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
old_date = datetime.date.today() - datetime.timedelta(days=10)
old_log_name = f'langbot-{old_date.isoformat()}.log'
mock_entry = Mock(spec=Path)
mock_entry.is_file = Mock(return_value=True)
mock_entry.name = old_log_name
mock_entry.__str__ = Mock(return_value='/data/logs/' + old_log_name)
mock_stat = Mock()
mock_stat.st_size = 1000
mock_entry.stat = Mock(return_value=mock_stat)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry]
result = service._expired_log_candidates(3, include_paths=True)
# Verify - path included
assert 'path' in result[0]
class TestMaintenanceServiceExpiredLocalUploadCandidates:
"""Tests for _expired_local_upload_candidates method."""
def test_expired_local_upload_candidates_nonexistent_dir(self):
"""Returns empty list when storage dir not exists."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
with patch.object(Path, 'exists', return_value=False):
result = service._expired_local_upload_candidates(7)
# Verify
assert result == []
def test_expired_local_upload_candidates_filters_uploaded(self):
"""Only returns uploaded files matching pattern."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
# Mock _is_uploaded_file_key
service._is_uploaded_file_key = Mock(side_effect=lambda key: 'plugin_config_' not in key and '/' not in key)
# Create mock files - one valid, one plugin config
mock_entry_valid = Mock(spec=Path)
mock_entry_valid.is_file = Mock(return_value=True)
mock_entry_valid.name = 'valid_upload.txt'
mock_stat = Mock()
mock_stat.st_size = 100
mock_stat.st_mtime = 0 # Very old
mock_entry_valid.stat = Mock(return_value=mock_stat)
mock_entry_plugin = Mock(spec=Path)
mock_entry_plugin.is_file = Mock(return_value=True)
mock_entry_plugin.name = 'plugin_config_test.json'
mock_stat2 = Mock()
mock_stat2.st_size = 200
mock_stat2.st_mtime = 0
mock_entry_plugin.stat = Mock(return_value=mock_stat2)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry_valid, mock_entry_plugin]
result = service._expired_local_upload_candidates(7)
# Verify - only valid upload included
assert len(result) == 1
assert result[0]['key'] == 'valid_upload.txt'
def test_expired_local_upload_candidates_includes_path(self):
"""Includes path when include_paths=True."""
# Setup
ap = SimpleNamespace()
ap.logger = SimpleNamespace()
service = MaintenanceService(ap)
service._is_uploaded_file_key = Mock(return_value=True)
mock_entry = Mock(spec=Path)
mock_entry.is_file = Mock(return_value=True)
mock_entry.name = 'old_file.txt'
mock_entry.__str__ = Mock(return_value='/data/storage/old_file.txt')
mock_stat = Mock()
mock_stat.st_size = 100
mock_stat.st_mtime = 0
mock_entry.stat = Mock(return_value=mock_stat)
with patch.object(Path, 'exists', return_value=True):
with patch.object(Path, 'iterdir') as mock_iterdir:
mock_iterdir.return_value = [mock_entry]
result = service._expired_local_upload_candidates(7, include_paths=True)
# Verify - path included
assert 'path' in result[0]

View File

@@ -1,648 +0,0 @@
"""
Unit tests for MCPService.
Tests MCP server CRUD operations including:
- MCP server listing with runtime info
- MCP server creation with limitations
- MCP server update with enable/disable
- MCP server deletion
- MCP server connection testing
Source: src/langbot/pkg/api/http/service/mcp.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, MagicMock
from types import SimpleNamespace
import uuid
from langbot.pkg.api.http.service.mcp import MCPService
from langbot.pkg.entity.persistence.mcp import MCPServer
pytestmark = pytest.mark.asyncio
def _create_mock_mcp_server(
server_uuid: str = None,
name: str = 'Test MCP Server',
enable: bool = True,
mode: str = 'stdio',
extra_args: dict = None,
) -> Mock:
"""Helper to create mock MCPServer entity."""
server = Mock(spec=MCPServer)
server.uuid = server_uuid or str(uuid.uuid4())
server.name = name
server.enable = enable
server.mode = mode
server.extra_args = extra_args or {}
return server
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestMCPServiceGetRuntimeInfo:
"""Tests for get_runtime_info method."""
async def test_get_runtime_info_session_exists(self):
"""Returns runtime info when session exists."""
# Setup
ap = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
mock_session = SimpleNamespace()
mock_session.get_runtime_info_dict = Mock(return_value={'status': 'running', 'tools': 5})
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
service = MCPService(ap)
# Execute
result = await service.get_runtime_info('test-server')
# Verify
assert result is not None
assert result['status'] == 'running'
async def test_get_runtime_info_session_not_exists(self):
"""Returns None when session not exists."""
# Setup
ap = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
service = MCPService(ap)
# Execute
result = await service.get_runtime_info('nonexistent-server')
# Verify
assert result is None
class TestMCPServiceGetMCPServers:
"""Tests for get_mcp_servers method."""
async def test_get_mcp_servers_empty_list(self):
"""Returns empty list when no MCP servers exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
}
)
ap.tool_mgr = None
service = MCPService(ap)
# Execute
result = await service.get_mcp_servers()
# Verify
assert result == []
async def test_get_mcp_servers_returns_serialized_list(self):
"""Returns serialized list of MCP servers."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
server2 = _create_mock_mcp_server(server_uuid='uuid-2', name='Server 2')
mock_result = _create_mock_result([server1, server2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'enable': entity.enable,
'mode': entity.mode,
}
)
ap.tool_mgr = None
service = MCPService(ap)
# Execute
result = await service.get_mcp_servers()
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Server 1'
assert result[1]['name'] == 'Server 2'
async def test_get_mcp_servers_with_runtime_info(self):
"""Returns MCP servers with runtime info when requested."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
server1 = _create_mock_mcp_server(server_uuid='uuid-1', name='Server 1')
mock_result = _create_mock_result([server1])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
}
)
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
service = MCPService(ap)
service.get_runtime_info = AsyncMock(return_value={'status': 'connected'})
# Execute
result = await service.get_mcp_servers(contain_runtime_info=True)
# Verify - runtime info included
assert result[0]['runtime_info'] == {'status': 'connected'}
class TestMCPServiceCreateMCPServer:
"""Tests for create_mcp_server method."""
async def test_create_mcp_server_max_extensions_reached_raises(self):
"""Raises ValueError when max_extensions limit reached."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': 2
}
}
}
ap.plugin_connector = SimpleNamespace()
ap.plugin_connector.list_plugins = AsyncMock(return_value=[Mock(), Mock()]) # 2 plugins
# Mock get_mcp_servers to return 0 servers (2 plugins already)
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={})
ap.tool_mgr = None
service = MCPService(ap)
# Execute & Verify - 2 plugins + new server would exceed limit
with pytest.raises(ValueError, match='Maximum number of extensions'):
await service.create_mcp_server({'name': 'New Server'})
async def test_create_mcp_server_no_limit(self):
"""Creates MCP server without limit when max_extensions=-1."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_extensions': -1 # No limit
}
}
}
ap.tool_mgr = None
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
service = MCPService(ap)
# Execute
server_uuid = await service.create_mcp_server({'name': 'New Server'})
# Verify
assert server_uuid is not None
assert len(server_uuid) == 36 # UUID format
async def test_create_mcp_server_loads_server(self):
"""Loads server into tool_mgr when enabled."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
# Create mock server entity
server_entity = _create_mock_mcp_server(server_uuid='new-uuid', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result([]) # Empty list for limit check
elif call_count == 2:
return Mock() # Insert
return _create_mock_result(first_item=server_entity) # Select created
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Server', 'enable': True}
)
service = MCPService(ap)
# Execute
await service.create_mcp_server({'name': 'New Server', 'enable': True})
# Verify - host_mcp_server was called
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
async def test_create_mcp_server_disabled_no_load(self):
"""Does not load server when disabled."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_extensions': -1}}}
ap.tool_mgr = None
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={'uuid': 'new-uuid'})
service = MCPService(ap)
# Execute with enable=False
server_uuid = await service.create_mcp_server({'name': 'New Server', 'enable': False})
# Verify - no tool_mgr load attempt
assert server_uuid is not None
class TestMCPServiceGetMCPServerByName:
"""Tests for get_mcp_server_by_name method."""
async def test_get_mcp_server_by_name_found(self):
"""Returns MCP server when found by name."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
server = _create_mock_mcp_server(name='Found Server')
mock_result = _create_mock_result(first_item=server)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'name': 'Found Server',
'runtime_info': None,
}
)
ap.tool_mgr = None
service = MCPService(ap)
service.get_runtime_info = AsyncMock(return_value=None)
# Execute
result = await service.get_mcp_server_by_name('Found Server')
# Verify
assert result is not None
assert result['name'] == 'Found Server'
async def test_get_mcp_server_by_name_not_found(self):
"""Returns None when MCP server not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = MCPService(ap)
# Execute
result = await service.get_mcp_server_by_name('Nonexistent Server')
# Verify
assert result is None
class TestMCPServiceUpdateMCPServer:
"""Tests for update_mcp_server method."""
async def test_update_mcp_server_disable_enabled_server(self):
"""Removes server when disabling previously enabled server."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=old_server)
return Mock() # Update
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MCPService(ap)
# Execute - disable server
await service.update_mcp_server('test-uuid', {'enable': False})
# Verify - server was removed
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once()
async def test_update_mcp_server_enable_disabled_server(self):
"""Loads server when enabling previously disabled server."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {}
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
old_server = _create_mock_mcp_server(name='Old Server', enable=False)
updated_server = _create_mock_mcp_server(name='Old Server', enable=True)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=old_server)
elif call_count == 2:
return Mock() # Update
return _create_mock_result(first_item=updated_server) # Select updated
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
)
service = MCPService(ap)
# Execute - enable server
await service.update_mcp_server('test-uuid', {'enable': True})
# Verify - server was loaded
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
async def test_update_mcp_server_update_enabled_server(self):
"""Removes and reloads server when updating enabled server."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {'Old Server': Mock()}
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
ap.tool_mgr.mcp_tool_loader.host_mcp_server = AsyncMock()
ap.tool_mgr.mcp_tool_loader._hosted_mcp_tasks = []
old_server = _create_mock_mcp_server(name='Old Server', enable=True)
# Mock for: first select -> update -> second select (for updated server)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
# All selects return the server
return _create_mock_result(first_item=old_server)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'test-uuid', 'name': 'Old Server', 'enable': True}
)
service = MCPService(ap)
# Execute - update enabled server (keep enabled, update extra_args)
await service.update_mcp_server('test-uuid', {'enable': True, 'extra_args': {'new': 'args'}})
# Verify - remove and reload
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Old Server')
ap.tool_mgr.mcp_tool_loader.host_mcp_server.assert_called_once()
async def test_update_mcp_server_no_tool_mgr(self):
"""Updates persistence without tool_mgr operations."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Set mcp_tool_loader to None, not tool_mgr itself
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = None
old_server = _create_mock_mcp_server(name='Server', enable=True)
# Mock execute for select and update
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=old_server)
return Mock() # Update
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MCPService(ap)
# Execute - should not raise
await service.update_mcp_server('test-uuid', {'name': 'New Name'})
# Verify - persistence was called
assert ap.persistence_mgr.execute_async.call_count >= 2
class TestMCPServiceDeleteMCPServer:
"""Tests for delete_mcp_server method."""
async def test_delete_mcp_server_calls_remove_and_delete(self):
"""Calls both persistence delete and tool_mgr remove."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {'Server to Delete': Mock()}
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
server = _create_mock_mcp_server(name='Server to Delete')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=server)
return Mock() # Delete
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MCPService(ap)
# Execute
await service.delete_mcp_server('test-uuid')
# Verify
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_called_once_with('Server to Delete')
ap.persistence_mgr.execute_async.assert_called()
async def test_delete_mcp_server_not_in_sessions(self):
"""Does not attempt remove if server not in sessions."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {} # Server not in sessions
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
server = _create_mock_mcp_server(name='Not in Sessions')
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=server)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MCPService(ap)
# Execute
await service.delete_mcp_server('test-uuid')
# Verify - remove not called (server not in sessions)
ap.tool_mgr.mcp_tool_loader.remove_mcp_server.assert_not_called()
async def test_delete_mcp_server_nonexistent_uuid(self):
"""Delete operation completes even for nonexistent UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.sessions = {}
ap.tool_mgr.mcp_tool_loader.remove_mcp_server = AsyncMock()
# No server found
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=None)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = MCPService(ap)
# Execute - should not raise
await service.delete_mcp_server('nonexistent-uuid')
# Verify - delete was called regardless
ap.persistence_mgr.execute_async.assert_called()
class TestMCPServiceTestMCPServer:
"""Tests for test_mcp_server method."""
async def test_test_mcp_server_existing_server(self):
"""Tests existing MCP server connection."""
# Setup
ap = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
from langbot.pkg.provider.tools.loaders.mcp import MCPSessionStatus
mock_session = MagicMock()
mock_session.status = MCPSessionStatus.ERROR
mock_session.start = AsyncMock()
mock_session.refresh = AsyncMock()
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=123)
)
service = MCPService(ap)
# Execute
task_id = await service.test_mcp_server('existing-server', {})
# Verify - returns task ID
assert task_id == 123
async def test_test_mcp_server_not_found_raises(self):
"""Raises ValueError when server not found."""
# Setup
ap = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader.get_session = Mock(return_value=None)
service = MCPService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Server not found'):
await service.test_mcp_server('nonexistent-server', {})
async def test_test_mcp_server_new_server(self):
"""Tests new MCP server with underscore name."""
# Setup
ap = SimpleNamespace()
ap.tool_mgr = SimpleNamespace()
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
mock_session = MagicMock()
mock_session.start = AsyncMock()
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
ap.task_mgr = SimpleNamespace()
ap.task_mgr.create_user_task = Mock(
return_value=SimpleNamespace(id=456)
)
service = MCPService(ap)
# Execute with '_' name (new server)
task_id = await service.test_mcp_server('_', {'name': 'New Server'})
# Verify - load_mcp_server called
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
assert task_id == 456

View File

@@ -1,964 +0,0 @@
"""
Unit tests for LLMModelsService, EmbeddingModelsService, and RerankModelsService.
Tests model management operations including:
- Model CRUD operations
- Model with provider info
- Provider auto-creation on model create/update
- Runtime model loading/unloading
- Model deletion
Source: src/langbot/pkg/api/http/service/model.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.api.http.service.model import (
LLMModelsService,
EmbeddingModelsService,
RerankModelsService,
_parse_provider_api_keys,
_runtime_model_data,
)
from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, RerankModel, ModelProvider
pytestmark = pytest.mark.asyncio
def _create_mock_llm_model(
model_uuid: str = 'llm-uuid',
name: str = 'Test LLM',
provider_uuid: str = 'provider-uuid',
abilities: list = None,
extra_args: dict = None,
) -> Mock:
"""Helper to create mock LLMModel entity."""
model = Mock(spec=LLMModel)
model.uuid = model_uuid
model.name = name
model.provider_uuid = provider_uuid
model.abilities = abilities or []
model.extra_args = extra_args or {}
return model
def _create_mock_embedding_model(
model_uuid: str = 'embedding-uuid',
name: str = 'Test Embedding',
provider_uuid: str = 'provider-uuid',
) -> Mock:
"""Helper to create mock EmbeddingModel entity."""
model = Mock(spec=EmbeddingModel)
model.uuid = model_uuid
model.name = name
model.provider_uuid = provider_uuid
model.extra_args = {}
return model
def _create_mock_rerank_model(
model_uuid: str = 'rerank-uuid',
name: str = 'Test Rerank',
provider_uuid: str = 'provider-uuid',
) -> Mock:
"""Helper to create mock RerankModel entity."""
model = Mock(spec=RerankModel)
model.uuid = model_uuid
model.name = name
model.provider_uuid = provider_uuid
model.extra_args = {}
return model
def _create_mock_provider(
provider_uuid: str = 'provider-uuid',
name: str = 'Test Provider',
api_keys: list = None,
) -> Mock:
"""Helper to create mock ModelProvider entity."""
provider = Mock(spec=ModelProvider)
provider.uuid = provider_uuid
provider.name = name
provider.requester = 'openai'
provider.base_url = 'https://api.openai.com'
provider.api_keys = api_keys or ['key']
return provider
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestParseProviderApiKeys:
"""Tests for _parse_provider_api_keys helper function."""
def test_parse_valid_json_string(self):
"""Parses valid JSON string to list."""
provider_dict = {'api_keys': '["key1", "key2"]'}
result = _parse_provider_api_keys(provider_dict)
assert result['api_keys'] == ['key1', 'key2']
def test_parse_invalid_json_returns_empty(self):
"""Returns empty list for invalid JSON."""
provider_dict = {'api_keys': 'invalid json'}
result = _parse_provider_api_keys(provider_dict)
assert result['api_keys'] == []
def test_parse_already_list(self):
"""Returns unchanged if already a list."""
provider_dict = {'api_keys': ['key1', 'key2']}
result = _parse_provider_api_keys(provider_dict)
assert result['api_keys'] == ['key1', 'key2']
def test_parse_missing_key(self):
"""Handles missing api_keys key."""
provider_dict = {'name': 'Provider'}
result = _parse_provider_api_keys(provider_dict)
assert 'api_keys' not in result
class TestRuntimeModelData:
"""Tests for _runtime_model_data helper function."""
def test_runtime_data_preserves_uuid(self):
"""Preserves UUID in runtime data."""
update_payload = {'name': 'Updated', 'provider_uuid': 'provider'}
result = _runtime_model_data('model-uuid', update_payload)
assert result['uuid'] == 'model-uuid'
assert result['name'] == 'Updated'
def test_runtime_data_copies_all_fields(self):
"""Copies all fields from payload."""
update_payload = {
'name': 'Model',
'provider_uuid': 'provider',
'abilities': ['vision'],
'extra_args': {'temp': 0.7},
}
result = _runtime_model_data('uuid', update_payload)
assert result['abilities'] == ['vision']
assert result['extra_args'] == {'temp': 0.7}
class TestLLMModelsServiceGetLLMModels:
"""Tests for LLMModelsService.get_llm_models method."""
async def test_get_llm_models_empty_list(self):
"""Returns empty list when no models exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
mock_provider_result = _create_mock_result([])
call_count = 0
async def mock_execute(query):
return mock_result if call_count == 0 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': entity.provider_uuid,
}
)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_models()
# Verify
assert result == []
async def test_get_llm_models_with_provider_info(self):
"""Returns models with provider info."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model()
provider = _create_mock_provider()
mock_model_result = _create_mock_result([model])
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
}
)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_models()
# Verify
assert len(result) == 1
assert result[0]['name'] == 'Test LLM'
async def test_get_llm_models_hide_secret_keys(self):
"""Hides secret API keys when include_secret=False."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model()
provider = _create_mock_provider(api_keys=['secret-key-1', 'secret-key-2'])
mock_model_result = _create_mock_result([model])
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
}
)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_models(include_secret=False)
# Verify - keys should be masked
assert result[0]['provider']['api_keys'] == ['***', '***']
class TestLLMModelsServiceGetLLMModel:
"""Tests for LLMModelsService.get_llm_model method."""
async def test_get_llm_model_found(self):
"""Returns model when found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model(model_uuid='found-uuid')
provider = _create_mock_provider()
mock_model_result = _create_mock_result([], first_item=model)
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-uuid',
'name': 'Test LLM',
'provider_uuid': 'provider-uuid',
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
}
)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_model('found-uuid')
# Verify
assert result is not None
assert result['uuid'] == 'found-uuid'
async def test_get_llm_model_not_found(self):
"""Returns None when model not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([], first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_model('nonexistent-uuid')
# Verify
assert result is None
class TestLLMModelsServiceGetLLMModelsByProvider:
"""Tests for LLMModelsService.get_llm_models_by_provider method."""
async def test_get_models_by_provider_uuid(self):
"""Returns models for specific provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model1 = _create_mock_llm_model(model_uuid='model-1', provider_uuid='target-provider')
model2 = _create_mock_llm_model(model_uuid='model-2', provider_uuid='target-provider')
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'model-1', 'name': 'Model 1'}
)
service = LLMModelsService(ap)
# Execute
result = await service.get_llm_models_by_provider('target-provider')
# Verify
assert len(result) == 2
class TestLLMModelsServiceCreateLLMModel:
"""Tests for LLMModelsService.create_llm_model method."""
async def test_create_llm_model_generates_uuid(self):
"""Creates LLM model with generated UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.pipeline_service = SimpleNamespace()
ap.pipeline_service.update_pipeline = AsyncMock()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'name': 'New LLM',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
})
# Verify
assert model_uuid is not None
assert len(model_uuid) == 36 # UUID format
async def test_create_llm_model_preserve_uuid(self):
"""Creates LLM model preserving provided UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.pipeline_service = SimpleNamespace()
ap.pipeline_service.update_pipeline = AsyncMock()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
# Execute
model_uuid = await service.create_llm_model({
'uuid': 'preserved-uuid',
'name': 'Preserved UUID Model',
'provider_uuid': 'provider-uuid',
'abilities': [],
'extra_args': {},
}, preserve_uuid=True)
# Verify
assert model_uuid == 'preserved-uuid'
async def test_create_llm_model_provider_not_found_raises_error(self):
"""Raises Exception when provider not found in runtime."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {} # Empty - no provider
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_llm_model({
'name': 'No Provider Model',
'provider_uuid': 'nonexistent-provider',
'abilities': [],
'extra_args': {},
})
async def test_create_llm_model_with_provider_data(self):
"""Creates provider when provider data provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.provider_service = SimpleNamespace()
ap.provider_service.find_or_create_provider = AsyncMock(return_value='new-provider-uuid')
ap.pipeline_service = SimpleNamespace()
ap.pipeline_service.update_pipeline = AsyncMock()
# Create runtime provider
runtime_provider = Mock()
ap.model_mgr.provider_dict['new-provider-uuid'] = runtime_provider
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
# Execute - with provider data (no UUID)
result_uuid = await service.create_llm_model({
'name': 'Model with New Provider',
'provider': {
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
},
'abilities': [],
'extra_args': {},
})
# Verify - provider_service was called and UUID generated
ap.provider_service.find_or_create_provider.assert_called_once()
assert result_uuid is not None
class TestLLMModelsServiceUpdateLLMModel:
"""Tests for LLMModelsService.update_llm_model method."""
async def test_update_llm_model_removes_uuid_from_data(self):
"""Removes uuid from update data before persisting."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.remove_llm_model = AsyncMock()
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.persistence_mgr.execute_async = AsyncMock()
service = LLMModelsService(ap)
# Execute
await service.update_llm_model('existing-uuid', {
'uuid': 'should-be-removed',
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
})
# Verify - remove and load called
ap.model_mgr.remove_llm_model.assert_called_once_with('existing-uuid')
async def test_update_llm_model_provider_not_found_raises_error(self):
"""Raises Exception when provider not found after update."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {} # Empty
ap.model_mgr.remove_llm_model = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = LLMModelsService(ap)
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.update_llm_model('model-uuid', {
'name': 'Update',
'provider_uuid': 'nonexistent-provider',
})
class TestLLMModelsServiceDeleteLLMModel:
"""Tests for LLMModelsService.delete_llm_model method."""
async def test_delete_llm_model_success(self):
"""Deletes LLM model successfully."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.remove_llm_model = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = LLMModelsService(ap)
# Execute
await service.delete_llm_model('delete-uuid')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
ap.model_mgr.remove_llm_model.assert_called_once_with('delete-uuid')
class TestEmbeddingModelsServiceGetEmbeddingModels:
"""Tests for EmbeddingModelsService.get_embedding_models method."""
async def test_get_embedding_models_empty_list(self):
"""Returns empty list when no models exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'embedding-uuid', 'name': 'Test'}
)
service = EmbeddingModelsService(ap)
# Execute
result = await service.get_embedding_models()
# Verify
assert result == []
async def test_get_embedding_models_with_provider(self):
"""Returns embedding models with provider info."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_embedding_model()
provider = _create_mock_provider()
mock_model_result = _create_mock_result([model])
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': getattr(entity, 'provider_uuid', None),
'api_keys': getattr(entity, 'api_keys', ['key']),
}
)
service = EmbeddingModelsService(ap)
# Execute
result = await service.get_embedding_models()
# Verify
assert len(result) == 1
class TestEmbeddingModelsServiceGetEmbeddingModel:
"""Tests for EmbeddingModelsService.get_embedding_model method."""
async def test_get_embedding_model_found(self):
"""Returns embedding model when found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_embedding_model(model_uuid='found-embedding')
provider = _create_mock_provider()
mock_model_result = _create_mock_result([], first_item=model)
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-embedding',
'name': 'Found Embedding',
'provider': {'uuid': 'provider-uuid'},
}
)
service = EmbeddingModelsService(ap)
# Execute
result = await service.get_embedding_model('found-embedding')
# Verify
assert result is not None
async def test_get_embedding_model_not_found(self):
"""Returns None when model not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([], first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = EmbeddingModelsService(ap)
# Execute
result = await service.get_embedding_model('nonexistent-embedding')
# Verify
assert result is None
class TestEmbeddingModelsServiceCreateEmbeddingModel:
"""Tests for EmbeddingModelsService.create_embedding_model method."""
async def test_create_embedding_model_success(self):
"""Creates embedding model successfully."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.embedding_models = []
ap.model_mgr.load_embedding_model_with_provider = AsyncMock(return_value=Mock())
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = EmbeddingModelsService(ap)
# Execute
model_uuid = await service.create_embedding_model({
'name': 'New Embedding',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
# Verify
assert model_uuid is not None
assert len(model_uuid) == 36
async def test_create_embedding_model_provider_not_found_raises(self):
"""Raises Exception when provider not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {} # Empty
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = EmbeddingModelsService(ap)
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_embedding_model({
'name': 'No Provider Embedding',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
class TestEmbeddingModelsServiceDeleteEmbeddingModel:
"""Tests for EmbeddingModelsService.delete_embedding_model method."""
async def test_delete_embedding_model_success(self):
"""Deletes embedding model successfully."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.remove_embedding_model = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = EmbeddingModelsService(ap)
# Execute
await service.delete_embedding_model('delete-embedding-uuid')
# Verify
ap.model_mgr.remove_embedding_model.assert_called_once()
class TestRerankModelsServiceGetRerankModels:
"""Tests for RerankModelsService.get_rerank_models method."""
async def test_get_rerank_models_empty_list(self):
"""Returns empty list when no models exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = RerankModelsService(ap)
# Execute
result = await service.get_rerank_models()
# Verify
assert result == []
async def test_get_rerank_models_with_provider(self):
"""Returns rerank models with provider info."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_rerank_model()
provider = _create_mock_provider()
mock_model_result = _create_mock_result([model])
mock_provider_result = _create_mock_result([provider])
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': getattr(entity, 'provider_uuid', None),
'api_keys': getattr(entity, 'api_keys', ['key']),
}
)
service = RerankModelsService(ap)
# Execute
result = await service.get_rerank_models()
# Verify
assert len(result) == 1
class TestRerankModelsServiceGetRerankModel:
"""Tests for RerankModelsService.get_rerank_model method."""
async def test_get_rerank_model_found(self):
"""Returns rerank model when found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_rerank_model(model_uuid='found-rerank')
provider = _create_mock_provider()
mock_model_result = _create_mock_result([], first_item=model)
mock_provider_result = _create_mock_result([], first_item=provider)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
return mock_model_result if call_count == 1 else mock_provider_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-rerank',
'name': 'Found Rerank',
'provider': {'uuid': 'provider-uuid'},
}
)
service = RerankModelsService(ap)
# Execute
result = await service.get_rerank_model('found-rerank')
# Verify
assert result is not None
async def test_get_rerank_model_not_found(self):
"""Returns None when model not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([], first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = RerankModelsService(ap)
# Execute
result = await service.get_rerank_model('nonexistent-rerank')
# Verify
assert result is None
class TestRerankModelsServiceCreateRerankModel:
"""Tests for RerankModelsService.create_rerank_model method."""
async def test_create_rerank_model_success(self):
"""Creates rerank model successfully."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.rerank_models = []
ap.model_mgr.load_rerank_model_with_provider = AsyncMock(return_value=Mock())
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = RerankModelsService(ap)
# Execute
model_uuid = await service.create_rerank_model({
'name': 'New Rerank',
'provider_uuid': 'provider-uuid',
'extra_args': {},
})
# Verify
assert model_uuid is not None
async def test_create_rerank_model_provider_not_found_raises(self):
"""Raises Exception when provider not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = RerankModelsService(ap)
# Execute & Verify
with pytest.raises(Exception, match='provider not found'):
await service.create_rerank_model({
'name': 'No Provider Rerank',
'provider_uuid': 'nonexistent',
'extra_args': {},
})
class TestRerankModelsServiceDeleteRerankModel:
"""Tests for RerankModelsService.delete_rerank_model method."""
async def test_delete_rerank_model_success(self):
"""Deletes rerank model successfully."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.remove_rerank_model = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = RerankModelsService(ap)
# Execute
await service.delete_rerank_model('delete-rerank-uuid')
# Verify
ap.model_mgr.remove_rerank_model.assert_called_once()
class TestEmbeddingModelsServiceGetEmbeddingModelsByProvider:
"""Tests for EmbeddingModelsService.get_embedding_models_by_provider method."""
async def test_get_embedding_models_by_provider_uuid(self):
"""Returns embedding models for specific provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model1 = _create_mock_embedding_model(model_uuid='emb-1', provider_uuid='provider-uuid')
model2 = _create_mock_embedding_model(model_uuid='emb-2', provider_uuid='provider-uuid')
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'emb-1', 'name': 'Embedding 1'}
)
service = EmbeddingModelsService(ap)
# Execute
result = await service.get_embedding_models_by_provider('provider-uuid')
# Verify
assert len(result) == 2
class TestRerankModelsServiceGetRerankModelsByProvider:
"""Tests for RerankModelsService.get_rerank_models_by_provider method."""
async def test_get_rerank_models_by_provider_uuid(self):
"""Returns rerank models for specific provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model1 = _create_mock_rerank_model(model_uuid='rerank-1', provider_uuid='provider-uuid')
model2 = _create_mock_rerank_model(model_uuid='rerank-2', provider_uuid='provider-uuid')
mock_result = _create_mock_result([model1, model2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'rerank-1', 'name': 'Rerank 1'}
)
service = RerankModelsService(ap)
# Execute
result = await service.get_rerank_models_by_provider('provider-uuid')
# Verify
assert len(result) == 2

View File

@@ -1,831 +0,0 @@
"""
Unit tests for PipelineService.
Tests pipeline CRUD operations including:
- Pipeline listing with sorting
- Pipeline creation with default config
- Pipeline update with bot sync
- Pipeline copy functionality
- Extensions preferences management
Source: src/langbot/pkg/api/http/service/pipeline.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch, mock_open
from types import SimpleNamespace
import uuid
import json
from langbot.pkg.api.http.service.pipeline import PipelineService, default_stage_order
from langbot.pkg.entity.persistence.pipeline import LegacyPipeline
pytestmark = pytest.mark.asyncio
def _create_mock_pipeline(
pipeline_uuid: str = None,
name: str = 'Test Pipeline',
description: str = 'Test Description',
is_default: bool = False,
stages: list = None,
config: dict = None,
extensions_preferences: dict = None,
) -> Mock:
"""Helper to create mock LegacyPipeline entity."""
pipeline = Mock(spec=LegacyPipeline)
pipeline.uuid = pipeline_uuid or str(uuid.uuid4())
pipeline.name = name
pipeline.description = description
pipeline.emoji = '⚙️'
pipeline.is_default = is_default
pipeline.for_version = '1.0.0'
pipeline.stages = stages or default_stage_order.copy()
pipeline.config = config or {}
pipeline.extensions_preferences = extensions_preferences or {
'enable_all_plugins': True,
'enable_all_mcp_servers': True,
'plugins': [],
'mcp_servers': [],
}
return pipeline
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestPipelineServiceGetPipelineMetadata:
"""Tests for get_pipeline_metadata method."""
async def test_get_pipeline_metadata_returns_list(self):
"""Returns list of pipeline metadata configs."""
# Setup
ap = SimpleNamespace()
ap.pipeline_config_meta_trigger = {'trigger': {}}
ap.pipeline_config_meta_safety = {'safety': {}}
ap.pipeline_config_meta_ai = {'ai': {}}
ap.pipeline_config_meta_output = {'output': {}}
service = PipelineService(ap)
# Execute
result = await service.get_pipeline_metadata()
# Verify
assert len(result) == 4
assert 'trigger' in result[0]
assert 'safety' in result[1]
assert 'ai' in result[2]
assert 'output' in result[3]
class TestPipelineServiceGetPipelines:
"""Tests for get_pipelines method."""
async def test_get_pipelines_empty_list(self):
"""Returns empty list when no pipelines exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
}
)
service = PipelineService(ap)
# Execute
result = await service.get_pipelines()
# Verify
assert result == []
async def test_get_pipelines_returns_sorted_by_created_at_desc(self):
"""Returns pipelines sorted by created_at descending by default."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
pipeline1 = _create_mock_pipeline(pipeline_uuid='uuid-1', name='Pipeline 1')
pipeline2 = _create_mock_pipeline(pipeline_uuid='uuid-2', name='Pipeline 2')
mock_result = _create_mock_result([pipeline1, pipeline2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
}
)
service = PipelineService(ap)
# Execute
result = await service.get_pipelines()
# Verify
assert len(result) == 2
ap.persistence_mgr.execute_async.assert_called_once()
async def test_get_pipelines_sort_by_updated_at_asc(self):
"""Returns pipelines sorted by updated_at ascending."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={})
service = PipelineService(ap)
# Execute
await service.get_pipelines(sort_by='updated_at', sort_order='ASC')
# Verify - execute was called with sort parameters
ap.persistence_mgr.execute_async.assert_called_once()
class TestPipelineServiceGetPipeline:
"""Tests for get_pipeline method."""
async def test_get_pipeline_by_uuid_found(self):
"""Returns pipeline when found by UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
pipeline = _create_mock_pipeline(pipeline_uuid='test-uuid', name='Found Pipeline')
mock_result = _create_mock_result(first_item=pipeline)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'name': 'Found Pipeline',
'stages': default_stage_order,
}
)
service = PipelineService(ap)
# Execute
result = await service.get_pipeline('test-uuid')
# Verify
assert result is not None
assert result['uuid'] == 'test-uuid'
assert result['name'] == 'Found Pipeline'
async def test_get_pipeline_by_uuid_not_found(self):
"""Returns None when pipeline not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = PipelineService(ap)
# Execute
result = await service.get_pipeline('nonexistent-uuid')
# Verify
assert result is None
class TestPipelineServiceCreatePipeline:
"""Tests for create_pipeline method."""
async def test_create_pipeline_max_limit_reached_raises(self):
"""Raises ValueError when max_pipelines limit reached."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
mock_result = _create_mock_result([_create_mock_pipeline(), _create_mock_pipeline()])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'uuid-1', 'name': 'Pipeline 1'}
)
service = PipelineService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'):
await service.create_pipeline({'name': 'New Pipeline'})
async def test_create_pipeline_no_limit(self):
"""Creates pipeline without limit when max_pipelines=-1."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
service = PipelineService(ap)
# Override get_pipelines to return empty list (no limit check issue)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'})
# Mock persistence for insert
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'New Pipeline'}
)
# Mock the file read for default config - patch at the utils module level
default_config = {'trigger': {}, 'safety': {}, 'ai': {}, 'output': {}}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
bot_uuid = await service.create_pipeline({'name': 'New Pipeline'})
# Verify
assert bot_uuid is not None
assert len(bot_uuid) == 36 # UUID format
async def test_create_pipeline_as_default(self):
"""Creates pipeline with is_default=True."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True})
ap.persistence_mgr.execute_async = AsyncMock()
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'new-uuid', 'name': 'Default Pipeline', 'is_default': True}
)
# Mock the file read
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
await service.create_pipeline({'name': 'Default Pipeline'}, default=True)
# Verify - execute was called
ap.persistence_mgr.execute_async.assert_called()
async def test_create_pipeline_sets_default_extensions_preferences(self):
"""Sets default extensions_preferences when not provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
service.get_pipeline = AsyncMock(return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
})
insert_params = []
async def mock_execute(query):
params = query.compile().params
if 'extensions_preferences' in params:
insert_params.append(params)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'new-uuid',
'extensions_preferences': {},
}
)
default_config = {}
with patch('builtins.open', mock_open(read_data=json.dumps(default_config))):
with patch('langbot.pkg.utils.paths.get_resource_path', return_value='templates/default-pipeline-config.json'):
await service.create_pipeline({'name': 'New Pipeline'})
assert len(insert_params) == 1
assert insert_params[0]['extensions_preferences'] == {
'enable_all_plugins': True,
'enable_all_mcp_servers': True,
'plugins': [],
'mcp_servers': [],
}
class _MockResultWithBots:
"""Helper class to mock SQLAlchemy result with iterable .all() method."""
def __init__(self, bots_list):
self._bots_list = bots_list
def all(self):
return self._bots_list
def first(self):
return self._bots_list[0] if self._bots_list else None
class TestPipelineServiceUpdatePipeline:
"""Tests for update_pipeline method."""
async def test_update_pipeline_removes_protected_fields(self):
"""Does not persist protected fields from update data."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
ap.bot_service = None # No bot_service when not updating name
ap.persistence_mgr.execute_async = AsyncMock()
service = PipelineService(ap)
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'Updated'})
# Execute with protected fields - no name change, so no bot sync
pipeline_data = {
'uuid': 'should-be-removed',
'for_version': 'should-be-removed',
'stages': ['should-be-removed'],
'is_default': True,
'description': 'New description', # Not name change, so no bot_service needed
}
await service.update_pipeline('test-uuid', pipeline_data)
update_params = ap.persistence_mgr.execute_async.await_args_list[0].args[0].compile().params
assert update_params['description'] == 'New description'
assert 'should-be-removed' not in update_params.values()
assert ['should-be-removed'] not in update_params.values()
assert not any(value is True for value in update_params.values())
async def test_update_pipeline_syncs_bot_names(self):
"""Updates bot use_pipeline_name when pipeline name changes."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.sess_mgr = SimpleNamespace()
ap.sess_mgr.session_list = []
ap.bot_service = SimpleNamespace()
ap.bot_service.update_bot = AsyncMock()
# Create proper mock Bot entities with uuid attribute
mock_bot1 = Mock()
mock_bot1.uuid = 'bot-uuid-1'
mock_bot2 = Mock()
mock_bot2.uuid = 'bot-uuid-2'
# Create bot list
bot_list = [mock_bot1, mock_bot2]
# Create mock result using helper class
bot_result = _MockResultWithBots(bot_list)
# The order of calls in update_pipeline:
# 1. UPDATE (line 125) - returns Mock (no result needed)
# 2. SELECT bots (line 136) - returns bot_result with .all()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
# First call is the UPDATE - just return a Mock
return Mock()
elif call_count == 2:
# Second call is the SELECT bots - return proper result
return bot_result
return Mock() # Any additional calls
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(return_value={})
service = PipelineService(ap)
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid', 'name': 'New Name'})
# Execute with name change
await service.update_pipeline('test-uuid', {'name': 'New Name'})
# Verify - bot_service.update_bot was called for each bot
assert ap.bot_service.update_bot.call_count == 2
async def test_update_pipeline_clears_conversations(self):
"""Clears session conversations using this pipeline."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.sess_mgr = SimpleNamespace()
# Mock session with conversation using this pipeline
session = SimpleNamespace()
session.using_conversation = SimpleNamespace()
session.using_conversation.pipeline_uuid = 'test-uuid'
ap.sess_mgr.session_list = [session]
ap.bot_service = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = PipelineService(ap)
service.get_pipeline = AsyncMock(return_value={'uuid': 'test-uuid'})
# Execute
await service.update_pipeline('test-uuid', {'description': 'Updated'})
# Verify - conversation was cleared
assert session.using_conversation is None
class TestPipelineServiceDeletePipeline:
"""Tests for delete_pipeline method."""
async def test_delete_pipeline_calls_remove_and_delete(self):
"""Calls both pipeline_mgr.remove_pipeline and persistence delete."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
service = PipelineService(ap)
# Execute
await service.delete_pipeline('test-uuid')
# Verify
ap.pipeline_mgr.remove_pipeline.assert_called_once_with('test-uuid')
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_pipeline_nonexistent_uuid(self):
"""Delete operation completes even for nonexistent UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
service = PipelineService(ap)
# Execute - should not raise
await service.delete_pipeline('nonexistent-uuid')
# Verify
ap.pipeline_mgr.remove_pipeline.assert_called_once()
class TestPipelineServiceCopyPipeline:
"""Tests for copy_pipeline method."""
async def test_copy_pipeline_max_limit_reached_raises(self):
"""Raises ValueError when max_pipelines limit reached."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'system': {
'limitation': {
'max_pipelines': 2
}
}
}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
service = PipelineService(ap)
# Mock get_pipelines to return 2 pipelines
service.get_pipelines = AsyncMock(return_value=[
{'uuid': 'uuid-1', 'name': 'Pipeline 1'},
{'uuid': 'uuid-2', 'name': 'Pipeline 2'},
])
# Execute & Verify
with pytest.raises(ValueError, match='Maximum number of pipelines'):
await service.copy_pipeline('original-uuid')
async def test_copy_pipeline_not_found_raises(self):
"""Raises ValueError when original pipeline not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
ap.persistence_mgr.execute_async = AsyncMock(
return_value=_create_mock_result(first_item=None) # Original not found
)
ap.persistence_mgr.serialize_model = Mock(return_value={})
# Execute & Verify
with pytest.raises(ValueError, match='Pipeline original-uuid not found'):
await service.copy_pipeline('original-uuid')
async def test_copy_pipeline_creates_copy(self):
"""Creates a copy with (Copy) suffix."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
original = _create_mock_pipeline(
pipeline_uuid='original-uuid',
name='Original Pipeline',
description='Original description',
stages=['Stage1', 'Stage2'],
config={'key': 'value'},
extensions_preferences={'enable_all_plugins': False, 'plugins': ['plugin1']},
)
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[]) # No limit check issue
# Mock persistence - get original, then insert, then get new
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'new-copy-uuid',
'name': 'Original Pipeline (Copy)',
}
)
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'new-copy-uuid',
'name': 'Original Pipeline (Copy)',
}
)
# Execute
new_uuid = await service.copy_pipeline('original-uuid')
# Verify
assert new_uuid is not None
assert len(new_uuid) == 36 # UUID format
async def test_copy_pipeline_is_not_default(self):
"""Copy is never set as default."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'limitation': {'max_pipelines': -1}}}
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.load_pipeline = AsyncMock()
ap.ver_mgr = SimpleNamespace()
ap.ver_mgr.get_current_version = Mock(return_value='1.0.0')
# Original is default
original = _create_mock_pipeline(
pipeline_uuid='original-uuid',
name='Default Pipeline',
is_default=True,
)
service = PipelineService(ap)
service.get_pipelines = AsyncMock(return_value=[])
ap.persistence_mgr.execute_async = AsyncMock(return_value=_create_mock_result(first_item=original))
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'copy-uuid', 'is_default': False}
)
service.get_pipeline = AsyncMock(return_value={'uuid': 'copy-uuid', 'is_default': False})
# Execute
await service.copy_pipeline('original-uuid')
# Verify - pipeline_mgr.load_pipeline called (copy created)
ap.pipeline_mgr.load_pipeline.assert_called_once()
class TestPipelineServiceUpdatePipelineExtensions:
"""Tests for update_pipeline_extensions method."""
async def test_update_extensions_pipeline_not_found_raises(self):
"""Raises ValueError when pipeline not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = PipelineService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Pipeline nonexistent-uuid not found'):
await service.update_pipeline_extensions('nonexistent-uuid', [])
async def test_update_extensions_sets_plugins(self):
"""Updates plugins in extensions_preferences."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(
extensions_preferences={'enable_all_plugins': True, 'plugins': []}
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=original_pipeline)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
}
)
service = PipelineService(ap)
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'test-uuid',
'extensions_preferences': {
'enable_all_plugins': False,
'plugins': [{'plugin_uuid': 'plugin-1'}],
}
}
)
# Execute
bound_plugins = [{'plugin_uuid': 'plugin-1'}]
await service.update_pipeline_extensions(
'test-uuid',
bound_plugins=bound_plugins,
enable_all_plugins=False,
)
# Verify
ap.persistence_mgr.execute_async.assert_called()
async def test_update_extensions_sets_mcp_servers(self):
"""Updates MCP servers in extensions_preferences."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline()
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=original_pipeline)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'test-uuid',
'extensions_preferences': {
'enable_all_mcp_servers': False,
'mcp_servers': ['mcp-server-1'],
}
}
)
service = PipelineService(ap)
service.get_pipeline = AsyncMock(
return_value={
'uuid': 'test-uuid',
'extensions_preferences': {'mcp_servers': ['mcp-server-1']},
}
)
# Execute
await service.update_pipeline_extensions(
'test-uuid',
bound_plugins=[],
bound_mcp_servers=['mcp-server-1'],
enable_all_mcp_servers=False,
)
# Verify
ap.persistence_mgr.execute_async.assert_called()
async def test_update_extensions_none_mcp_servers_keeps_existing(self):
"""Does not modify mcp_servers when bound_mcp_servers is None."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.pipeline_mgr = SimpleNamespace()
ap.pipeline_mgr.remove_pipeline = AsyncMock()
ap.pipeline_mgr.load_pipeline = AsyncMock()
original_pipeline = _create_mock_pipeline(
extensions_preferences={
'enable_all_plugins': True,
'enable_all_mcp_servers': True,
'plugins': [],
'mcp_servers': ['existing-server'],
}
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return _create_mock_result(first_item=original_pipeline)
return Mock()
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
)
service = PipelineService(ap)
service.get_pipeline = AsyncMock(
return_value={'uuid': 'test-uuid', 'extensions_preferences': {'mcp_servers': ['existing-server']}}
)
# Execute - bound_mcp_servers is None (not provided)
await service.update_pipeline_extensions('test-uuid', bound_plugins=[])
# Verify - persistence was called
ap.persistence_mgr.execute_async.assert_called()
class TestDefaultStageOrder:
"""Tests for default_stage_order constant."""
def test_default_stage_order_not_empty(self):
"""Default stage order is not empty."""
assert len(default_stage_order) > 0
def test_default_stage_order_contains_key_stages(self):
"""Default stage order contains key processing stages."""
assert 'MessageProcessor' in default_stage_order
assert 'SendResponseBackStage' in default_stage_order

View File

@@ -1,866 +0,0 @@
"""
Unit tests for ModelProviderService.
Tests model provider management operations including:
- Provider CRUD operations
- Provider model count checking
- Find or create provider logic
- Space model provider API key updates
- Provider model scanning
Source: src/langbot/pkg/api/http/service/provider.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.api.http.service.provider import ModelProviderService
from langbot.pkg.entity.persistence.model import ModelProvider, LLMModel, EmbeddingModel, RerankModel
pytestmark = pytest.mark.asyncio
def _create_mock_provider(
provider_uuid: str = 'test-provider-uuid',
name: str = 'Test Provider',
requester: str = 'openai',
base_url: str = 'https://api.openai.com',
api_keys: list = None,
) -> Mock:
"""Helper to create mock ModelProvider entity."""
provider = Mock(spec=ModelProvider)
provider.uuid = provider_uuid
provider.name = name
provider.requester = requester
provider.base_url = base_url
provider.api_keys = api_keys or ['test-key']
return provider
def _create_mock_llm_model(
model_uuid: str = 'test-llm-uuid',
name: str = 'Test LLM',
provider_uuid: str = 'test-provider-uuid',
) -> Mock:
"""Helper to create mock LLMModel entity."""
model = Mock(spec=LLMModel)
model.uuid = model_uuid
model.name = name
model.provider_uuid = provider_uuid
return model
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
result.scalar = Mock(return_value=len(items) if items else 0)
return result
class TestModelProviderServiceGetProviders:
"""Tests for get_providers method."""
async def test_get_providers_empty_list(self):
"""Returns empty list when no providers exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'requester': entity.requester,
'base_url': entity.base_url,
'api_keys': entity.api_keys,
}
)
service = ModelProviderService(ap)
# Execute
result = await service.get_providers()
# Verify
assert result == []
async def test_get_providers_returns_serialized_list(self):
"""Returns serialized list of providers."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
provider1 = _create_mock_provider(provider_uuid='provider-1', name='Provider 1')
provider2 = _create_mock_provider(provider_uuid='provider-2', name='Provider 2')
mock_result = _create_mock_result([provider1, provider2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'requester': entity.requester,
'base_url': entity.base_url,
'api_keys': entity.api_keys,
}
)
service = ModelProviderService(ap)
# Execute
result = await service.get_providers()
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Provider 1'
assert result[1]['name'] == 'Provider 2'
async def test_get_providers_parse_api_keys_json_string(self):
"""Parses api_keys from JSON string if needed."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='["key1", "key2"]')
mock_result = _create_mock_result([provider])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'api_keys': entity.api_keys, # Returns string
}
)
service = ModelProviderService(ap)
# Execute
result = await service.get_providers()
# Verify - api_keys should be parsed from string
assert result[0]['api_keys'] == ['key1', 'key2']
async def test_get_providers_invalid_json_api_keys_returns_empty(self):
"""Returns empty list for invalid JSON api_keys."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='provider-1', api_keys='invalid-json')
mock_result = _create_mock_result([provider])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'api_keys': entity.api_keys, # Returns invalid string
}
)
service = ModelProviderService(ap)
# Execute
result = await service.get_providers()
# Verify - invalid JSON returns empty list
assert result[0]['api_keys'] == []
class TestModelProviderServiceGetProvider:
"""Tests for get_provider method."""
async def test_get_provider_by_uuid_found(self):
"""Returns provider when found by UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='found-uuid', name='Found Provider')
mock_result = _create_mock_result([], first_item=provider)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-uuid',
'name': 'Found Provider',
'api_keys': ['key'],
}
)
service = ModelProviderService(ap)
# Execute
result = await service.get_provider('found-uuid')
# Verify
assert result is not None
assert result['uuid'] == 'found-uuid'
async def test_get_provider_by_uuid_not_found(self):
"""Returns None when provider not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([], first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute
result = await service.get_provider('nonexistent-uuid')
# Verify
assert result is None
class TestModelProviderServiceCreateProvider:
"""Tests for create_provider method."""
async def test_create_provider_generates_uuid(self):
"""Creates provider with generated UUID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
# Mock load_provider to return runtime provider
runtime_provider = Mock()
runtime_provider.provider_entity = Mock()
runtime_provider.provider_entity.uuid = 'generated-uuid'
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
ap.persistence_mgr.execute_async = AsyncMock()
service = ModelProviderService(ap)
# Execute
provider_uuid = await service.create_provider({
'name': 'New Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
# Verify - UUID is generated
assert provider_uuid is not None
assert len(provider_uuid) == 36 # UUID format
async def test_create_provider_loads_to_runtime(self):
"""Loads provider to runtime model_mgr."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
runtime_provider = Mock()
runtime_provider.provider_entity = Mock()
runtime_provider.provider_entity.uuid = 'runtime-uuid'
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
ap.persistence_mgr.execute_async = AsyncMock()
service = ModelProviderService(ap)
# Execute
result_uuid = await service.create_provider({
'name': 'Runtime Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
})
# Verify - provider added to runtime dict and UUID generated
ap.model_mgr.load_provider.assert_called_once()
assert result_uuid is not None
class TestModelProviderServiceUpdateProvider:
"""Tests for update_provider method."""
async def test_update_provider_removes_uuid_from_data(self):
"""Removes uuid from update data before persisting."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.reload_provider = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = ModelProviderService(ap)
# Execute
await service.update_provider('existing-uuid', {
'uuid': 'should-be-removed', # Will be removed
'name': 'Updated Name',
})
# Verify - reload called
ap.model_mgr.reload_provider.assert_called_once_with('existing-uuid')
async def test_update_provider_reloads_runtime(self):
"""Reloads provider in runtime after update."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.reload_provider = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = ModelProviderService(ap)
# Execute
await service.update_provider('update-uuid', {'name': 'New Name'})
# Verify
ap.model_mgr.reload_provider.assert_called_once()
class TestModelProviderServiceDeleteProvider:
"""Tests for delete_provider method."""
async def test_delete_provider_with_llm_models_raises_error(self):
"""Raises ValueError when LLM models reference provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Mock LLM model exists - only return LLM result since that's first check
llm_result = _create_mock_result([], first_item=_create_mock_llm_model())
ap.persistence_mgr.execute_async = AsyncMock(return_value=llm_result)
service = ModelProviderService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='Cannot delete provider: LLM models'):
await service.delete_provider('provider-with-llm')
async def test_delete_provider_with_embedding_models_raises_error(self):
"""Raises ValueError when Embedding models reference provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Create results for each check type
llm_result = Mock()
llm_result.first = Mock(return_value=None) # No LLM models
embedding_result = Mock()
embedding_result.first = Mock(return_value=Mock(spec=EmbeddingModel)) # Has embedding model
rerank_result = Mock()
rerank_result.first = Mock(return_value=None)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return llm_result
elif call_count == 2:
return embedding_result
return rerank_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = ModelProviderService(ap)
# Execute & Verify - should raise embedding error (LLM check passes, embedding check fails)
with pytest.raises(ValueError, match='Cannot delete provider: Embedding models'):
await service.delete_provider('provider-with-embedding')
async def test_delete_provider_with_rerank_models_raises_error(self):
"""Raises ValueError when Rerank models reference provider."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Create results for each check type
llm_result = Mock()
llm_result.first = Mock(return_value=None) # No LLM models
embedding_result = Mock()
embedding_result.first = Mock(return_value=None) # No embedding models
rerank_result = Mock()
rerank_result.first = Mock(return_value=Mock(spec=RerankModel)) # Has rerank model
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return llm_result
elif call_count == 2:
return embedding_result
return rerank_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = ModelProviderService(ap)
# Execute & Verify - should raise rerank error (LLM and embedding checks pass, rerank check fails)
with pytest.raises(ValueError, match='Cannot delete provider: Rerank models'):
await service.delete_provider('provider-with-rerank')
async def test_delete_provider_no_models_success(self):
"""Deletes provider when no models reference it."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.remove_provider = AsyncMock()
# Mock no models reference provider
empty_result = Mock()
empty_result.first = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=empty_result)
service = ModelProviderService(ap)
# Execute
await service.delete_provider('provider-no-models')
# Verify - delete and remove called
ap.model_mgr.remove_provider.assert_called_once_with('provider-no-models')
class TestModelProviderServiceGetProviderModelCounts:
"""Tests for get_provider_model_counts method."""
async def test_get_model_counts_returns_correct_counts(self):
"""Returns correct counts for each model type."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Mock scalar results for counts
llm_result = Mock()
llm_result.scalar = Mock(return_value=3)
embedding_result = Mock()
embedding_result.scalar = Mock(return_value=2)
rerank_result = Mock()
rerank_result.scalar = Mock(return_value=1)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return llm_result
elif call_count == 2:
return embedding_result
return rerank_result
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
service = ModelProviderService(ap)
# Execute
result = await service.get_provider_model_counts('provider-uuid')
# Verify
assert result['llm_count'] == 3
assert result['embedding_count'] == 2
assert result['rerank_count'] == 1
async def test_get_model_counts_zero_counts(self):
"""Returns zero counts when no models."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
zero_result = Mock()
zero_result.scalar = Mock(return_value=0)
ap.persistence_mgr.execute_async = AsyncMock(return_value=zero_result)
service = ModelProviderService(ap)
# Execute
result = await service.get_provider_model_counts('empty-provider')
# Verify
assert result['llm_count'] == 0
assert result['embedding_count'] == 0
assert result['rerank_count'] == 0
class TestModelProviderServiceFindOrCreateProvider:
"""Tests for find_or_create_provider method."""
async def test_find_existing_provider_matching_config(self):
"""Returns existing provider UUID when config matches."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
existing_provider = _create_mock_provider(
provider_uuid='existing-uuid',
requester='openai',
base_url='https://api.openai.com',
api_keys=['key1', 'key2'],
)
mock_result = _create_mock_result([existing_provider])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute
result = await service.find_or_create_provider(
requester='openai',
base_url='https://api.openai.com',
api_keys=['key1', 'key2'], # Same keys (sorted)
)
# Verify - returns existing UUID
assert result == 'existing-uuid'
async def test_find_existing_provider_keys_order_mismatch(self):
"""Returns existing provider when keys match but order differs."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
existing_provider = _create_mock_provider(
provider_uuid='existing-uuid',
requester='openai',
base_url='https://api.openai.com',
api_keys=['key1', 'key2'],
)
mock_result = _create_mock_result([existing_provider])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute with reversed key order
result = await service.find_or_create_provider(
requester='openai',
base_url='https://api.openai.com',
api_keys=['key2', 'key1'], # Different order, should still match
)
# Verify - returns existing UUID (keys are sorted in comparison)
assert result == 'existing-uuid'
async def test_create_new_provider_no_match(self):
"""Creates new provider when no existing match."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
runtime_provider = Mock()
runtime_provider.provider_entity = Mock()
runtime_provider.provider_entity.uuid = None # Will be set by uuid.uuid4()
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock no existing providers
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute
result = await service.find_or_create_provider(
requester='new-requester',
base_url='https://new.api.com',
api_keys=['new-key'],
)
# Verify - creates new provider with valid UUID format
assert result is not None
assert len(result) == 36 # UUID format
# Verify provider was loaded to runtime
ap.model_mgr.load_provider.assert_called_once()
async def test_create_provider_name_from_url_parse(self):
"""Creates provider with name parsed from URL."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {}
runtime_provider = Mock()
runtime_provider.provider_entity = Mock()
runtime_provider.provider_entity.uuid = 'parsed-url-uuid'
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute
result_uuid = await service.find_or_create_provider(
requester='custom',
base_url='https://api.example.com/v1',
api_keys=['key'],
)
# Verify - name should be parsed from URL (api.example.com)
ap.model_mgr.load_provider.assert_called_once()
assert result_uuid is not None
class TestModelProviderServiceUpdateSpaceModelProviderApiKeys:
"""Tests for update_space_model_provider_api_keys method."""
async def test_update_space_provider_api_keys(self):
"""Updates Space provider API keys."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.reload_provider = AsyncMock()
ap.persistence_mgr.execute_async = AsyncMock()
service = ModelProviderService(ap)
# Execute
await service.update_space_model_provider_api_keys('space-api-key')
# Verify - update and reload called for Space provider UUID
ap.model_mgr.reload_provider.assert_called_once_with(
'00000000-0000-0000-0000-000000000000'
)
class TestModelProviderServiceScanProviderModels:
"""Tests for scan_provider_models method."""
async def test_scan_provider_not_found_raises_error(self):
"""Raises ValueError when provider not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([], first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = ModelProviderService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='provider not found'):
await service.scan_provider_models('nonexistent-uuid')
async def test_scan_provider_returns_models_list(self):
"""Returns scanned models list."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.llm_model_service = SimpleNamespace()
ap.embedding_models_service = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='scan-uuid')
mock_result = _create_mock_result([], first_item=provider)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'scan-uuid',
'name': 'Scan Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
# Mock runtime provider with scan capability
runtime_provider = Mock()
runtime_provider.requester = Mock()
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
# Mock scan_models to return models
async def mock_scan_models(token):
return {
'models': [
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
],
'debug': None,
}
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing model services
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
# Execute
result = await service.scan_provider_models('scan-uuid')
# Verify
assert 'models' in result
assert len(result['models']) == 2
async def test_scan_provider_filter_by_model_type(self):
"""Returns filtered models by type."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.llm_model_service = SimpleNamespace()
ap.embedding_models_service = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='filter-uuid')
mock_result = _create_mock_result([], first_item=provider)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'filter-uuid',
'name': 'Filter Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
runtime_provider = Mock()
runtime_provider.requester = Mock()
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
async def mock_scan_models(token):
return {
'models': [
{'id': 'gpt-4', 'name': 'GPT-4', 'type': 'llm'},
{'id': 'text-embedding', 'name': 'Text Embedding', 'type': 'embedding'},
],
'debug': None,
}
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(return_value=[])
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
# Execute - filter for LLM only
result = await service.scan_provider_models('filter-uuid', model_type='llm')
# Verify - only LLM models returned
assert len(result['models']) == 1
assert result['models'][0]['type'] == 'llm'
async def test_scan_provider_not_implemented_raises_error(self):
"""Raises ValueError when scan not implemented."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='no-scan-uuid')
mock_result = _create_mock_result([], first_item=provider)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'no-scan-uuid',
'name': 'No Scan Provider',
'requester': 'custom',
'base_url': 'https://custom.api.com',
'api_keys': ['key'],
}
)
runtime_provider = Mock()
runtime_provider.requester = Mock()
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
runtime_provider.requester.scan_models = AsyncMock(
side_effect=NotImplementedError('scan not supported')
)
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
service = ModelProviderService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='current provider does not support model scanning'):
await service.scan_provider_models('no-scan-uuid')
async def test_scan_provider_marks_already_added_models(self):
"""Marks models that are already added."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.llm_model_service = SimpleNamespace()
ap.embedding_models_service = SimpleNamespace()
provider = _create_mock_provider(provider_uuid='already-added-uuid')
mock_result = _create_mock_result([], first_item=provider)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'already-added-uuid',
'name': 'Already Added Provider',
'requester': 'openai',
'base_url': 'https://api.openai.com',
'api_keys': ['key'],
}
)
runtime_provider = Mock()
runtime_provider.requester = Mock()
runtime_provider.token_mgr = Mock()
runtime_provider.token_mgr.get_token = Mock(return_value='token')
runtime_provider.token_mgr.tokens = ['token']
async def mock_scan_models(token):
return {
'models': [
{'id': 'existing-model', 'name': 'Existing Model', 'type': 'llm'},
{'id': 'new-model', 'name': 'New Model', 'type': 'llm'},
],
'debug': None,
}
runtime_provider.requester.scan_models = AsyncMock(side_effect=mock_scan_models)
ap.model_mgr.load_provider = AsyncMock(return_value=runtime_provider)
# Mock existing LLM model
ap.llm_model_service.get_llm_models_by_provider = AsyncMock(
return_value=[{'name': 'Existing Model'}]
)
ap.embedding_models_service.get_embedding_models_by_provider = AsyncMock(return_value=[])
service = ModelProviderService(ap)
# Execute
result = await service.scan_provider_models('already-added-uuid')
# Verify - existing model marked as already_added
existing_model = next(m for m in result['models'] if m['name'] == 'Existing Model')
assert existing_model['already_added'] is True
new_model = next(m for m in result['models'] if m['name'] == 'New Model')
assert new_model['already_added'] is False

View File

@@ -1,778 +0,0 @@
"""
Unit tests for SpaceService.
Tests LangBot Space API interactions including:
- OAuth URL generation
- Token exchange and refresh
- User info retrieval
- Credits caching
- Model listing
Source: src/langbot/pkg/api/http/service/space.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from types import SimpleNamespace
import datetime
import time
from langbot.pkg.api.http.service.space import SpaceService
from langbot.pkg.entity.persistence.user import User
pytestmark = pytest.mark.asyncio
def _create_mock_user(
email: str = 'test@example.com',
account_type: str = 'space',
space_account_uuid: str = 'space-uuid-123',
space_access_token: str = 'access_token_123',
space_refresh_token: str = 'refresh_token_123',
space_access_token_expires_at: datetime.datetime = None,
) -> Mock:
"""Helper to create mock User entity."""
user = Mock(spec=User)
user.user = email
user.account_type = account_type
user.space_account_uuid = space_account_uuid
user.space_access_token = space_access_token
user.space_refresh_token = space_refresh_token
user.space_access_token_expires_at = space_access_token_expires_at
return user
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestSpaceServiceGetOAuthAuthorizeUrl:
"""Tests for get_oauth_authorize_url method."""
def test_get_oauth_authorize_url_basic(self):
"""Returns OAuth URL with redirect_uri."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'space': {
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
}
}
service = SpaceService(ap)
# Execute
result = service.get_oauth_authorize_url('http://localhost/callback')
# Verify
assert 'redirect_uri=http://localhost/callback' in result
assert 'https://space.langbot.app/auth/authorize' in result
def test_get_oauth_authorize_url_with_state(self):
"""Returns OAuth URL with redirect_uri and state."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {
'space': {
'oauth_authorize_url': 'https://space.langbot.app/auth/authorize',
}
}
service = SpaceService(ap)
# Execute
result = service.get_oauth_authorize_url('http://localhost/callback', state='random_state')
# Verify
assert 'redirect_uri=http://localhost/callback' in result
assert 'state=random_state' in result
def test_get_oauth_authorize_url_default_config(self):
"""Uses default OAuth URL when config not set."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Execute
result = service.get_oauth_authorize_url('http://localhost/callback')
# Verify - uses default URL
assert 'https://space.langbot.app/auth/authorize' in result
class TestSpaceServiceGetUserByEmail:
"""Tests for _get_user_by_email internal method."""
async def test_get_user_by_email_found(self):
"""Returns user when found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(email='found@example.com')
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._get_user_by_email('found@example.com')
# Verify
assert result is not None
assert result.user == 'found@example.com'
async def test_get_user_by_email_not_found(self):
"""Returns None when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._get_user_by_email('notfound@example.com')
# Verify
assert result is None
class TestSpaceServiceEnsureValidToken:
"""Tests for _ensure_valid_token internal method."""
async def test_ensure_valid_token_user_not_found(self):
"""Returns None when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._ensure_valid_token('notfound@example.com')
# Verify
assert result is None
async def test_ensure_valid_token_not_space_account(self):
"""Returns None when user is not a space account."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(email='local@example.com', account_type='local')
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._ensure_valid_token('local@example.com')
# Verify
assert result is None
async def test_ensure_valid_token_no_access_token(self):
"""Returns None when user has no access token."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(space_access_token=None)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._ensure_valid_token('test@example.com')
# Verify
assert result is None
async def test_ensure_valid_token_valid_token(self):
"""Returns valid access token when not expired."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Token expires in 1 hour (valid)
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._ensure_valid_token('test@example.com')
# Verify
assert result == 'valid_token'
async def test_ensure_valid_token_expired_no_refresh(self):
"""Returns None when token expired and no refresh token."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Token expired 1 hour ago
mock_user = _create_mock_user(
space_access_token='expired_token',
space_refresh_token=None,
space_access_token_expires_at=datetime.datetime.now() - datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service._ensure_valid_token('test@example.com')
# Verify
assert result is None
class TestSpaceServiceGetCredits:
"""Tests for get_credits method."""
async def test_get_credits_no_user(self):
"""Returns None when user not found."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service.get_credits('notfound@example.com')
# Verify
assert result is None
async def test_get_credits_returns_cached_value(self):
"""Returns cached credits without API call."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Pre-populate cache
service._credits_cache = {'cached@example.com': (100, time.time())}
# Execute
result = await service.get_credits('cached@example.com')
# Verify - returns cached value without API call
assert result == 100
async def test_get_credits_cache_expired_refreshes(self):
"""Refreshes expired cache."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Pre-populate expired cache (70 seconds ago, past 60s TTL)
service._credits_cache = {'test@example.com': (50, time.time() - 70)}
# Mock get_user_info to return new credits
service.get_user_info = AsyncMock(return_value={'credits': 200})
# Execute
result = await service.get_credits('test@example.com')
# Verify - cache was refreshed
assert result == 200
assert service._credits_cache['test@example.com'][0] == 200
async def test_get_credits_force_refresh(self):
"""Force refresh ignores cache."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Pre-populate cache
service._credits_cache = {'test@example.com': (100, time.time())}
# Mock get_user_info to return new credits
service.get_user_info = AsyncMock(return_value={'credits': 300})
# Execute with force_refresh=True
result = await service.get_credits('test@example.com', force_refresh=True)
# Verify - fresh value returned
assert result == 300
async def test_get_credits_returns_cached_on_exception(self):
"""Returns cached fallback value when API fails."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Pre-populate expired cache - will try to refresh and fail
service._credits_cache = {'test@example.com': (150, time.time() - 70)}
# Mock get_user_info to raise exception
service.get_user_info = AsyncMock(side_effect=Exception('API Error'))
# Execute - should return cached fallback value (even though expired)
result = await service.get_credits('test@example.com')
# Verify - returns cached fallback value (150) because API failed
assert result == 150
class TestSpaceServiceRefreshToken:
"""Tests for refresh_token method."""
async def test_refresh_token_success(self):
"""Refreshes token successfully."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
}
})
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.post = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
# Use async context manager mock
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute
result = await service.refresh_token('old_refresh_token')
# Verify
assert result['access_token'] == 'new_access_token'
async def test_refresh_token_api_error(self):
"""Raises ValueError on API error."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 1,
'msg': 'Invalid refresh token',
})
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid refresh token"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.post = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='Failed to refresh token'):
await service.refresh_token('invalid_refresh_token')
async def test_refresh_token_http_error(self):
"""Raises ValueError on HTTP error."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with error status
mock_response = MagicMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value='Internal Server Error')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.post = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='Failed to refresh token'):
await service.refresh_token('refresh_token')
class TestSpaceServiceExchangeOAuthCode:
"""Tests for exchange_oauth_code method."""
async def test_exchange_oauth_code_success(self):
"""Exchanges OAuth code successfully."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'access_token': 'new_access_token',
'refresh_token': 'new_refresh_token',
'expires_in': 3600,
}
})
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.post = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute
result = await service.exchange_oauth_code('auth_code')
# Verify
assert result['access_token'] == 'new_access_token'
async def test_exchange_oauth_code_api_error(self):
"""Raises ValueError on API error."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Invalid code'})
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Invalid code"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.post = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.post.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='Failed to exchange OAuth code'):
await service.exchange_oauth_code('invalid_code')
class TestSpaceServiceGetUserInfoRaw:
"""Tests for get_user_info_raw method."""
async def test_get_user_info_raw_success(self):
"""Gets user info successfully."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'email': 'test@example.com',
'credits': 100,
}
})
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.get = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute
result = await service.get_user_info_raw('access_token')
# Verify
assert result['email'] == 'test@example.com'
assert result['credits'] == 100
async def test_get_user_info_raw_api_error(self):
"""Raises ValueError on API error."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.get = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='Failed to get user info'):
await service.get_user_info_raw('invalid_token')
class TestSpaceServiceGetUserInfo:
"""Tests for get_user_info method (with token validation)."""
async def test_get_user_info_no_token(self):
"""Returns None when no valid token."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Execute
result = await service.get_user_info('notfound@example.com')
# Verify
assert result is None
async def test_get_user_info_with_valid_token(self):
"""Returns user info with valid token."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Mock get_user_info_raw
service.get_user_info_raw = AsyncMock(return_value={'email': 'test@example.com', 'credits': 100})
# Execute
result = await service.get_user_info('test@example.com')
# Verify
assert result['email'] == 'test@example.com'
class TestSpaceServiceGetModels:
"""Tests for get_models method."""
async def test_get_models_success(self):
"""Gets models successfully."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with proper model data matching SpaceModel schema
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'code': 0,
'data': {
'models': [
{
'uuid': 'uuid-1',
'model_id': 'model-1',
'provider': 'provider-1',
'category': 'chat',
'status': 'active',
},
{
'uuid': 'uuid-2',
'model_id': 'model-2',
'provider': 'provider-2',
'category': 'chat',
'status': 'active',
},
]
}
})
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.get = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute
result = await service.get_models()
# Verify
assert len(result) == 2
async def test_get_models_api_error(self):
"""Raises ValueError on API error."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Mock HTTP response with error
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={'code': 1, 'msg': 'Unauthorized'})
mock_response.text = AsyncMock(return_value='{"code":1,"msg":"Unauthorized"}')
with patch('langbot.pkg.api.http.service.space.httpclient.get_session') as mock_session:
mock_session_obj = MagicMock()
mock_session_obj.get = MagicMock(return_value=mock_response)
mock_session.return_value = mock_session_obj
mock_session_obj.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session_obj.get.return_value.__aexit__ = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='Failed to get models'):
await service.get_models()
class TestSpaceServiceCreditsCache:
"""Tests for credits cache behavior."""
def test_credits_cache_initialized(self):
"""Verify _credits_cache is initialized as empty dict."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
service = SpaceService(ap)
# Verify
assert hasattr(service, '_credits_cache')
assert service._credits_cache == {}
async def test_credits_cache_updates_on_success(self):
"""Cache updates when get_credits succeeds."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {}
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
space_access_token='valid_token',
space_access_token_expires_at=datetime.datetime.now() + datetime.timedelta(hours=1),
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = SpaceService(ap)
# Mock get_user_info
service.get_user_info = AsyncMock(return_value={'credits': 500})
# Execute
result = await service.get_credits('test@example.com')
# Verify - cache updated
assert result == 500
assert 'test@example.com' in service._credits_cache
assert service._credits_cache['test@example.com'][0] == 500

View File

@@ -1,608 +0,0 @@
"""
Unit tests for UserService.
Tests user management operations including:
- User initialization check
- Local user creation and authentication
- JWT token generation and verification
- Password management (reset, change, set)
- Space account management
Source: src/langbot/pkg/api/http/service/user.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.api.http.service.user import UserService
from langbot.pkg.entity.persistence.user import User
from langbot.pkg.entity.errors.account import AccountEmailMismatchError
pytestmark = pytest.mark.asyncio
def _create_mock_user(
email: str = 'test@example.com',
password: str = 'hashed_password',
account_type: str = 'local',
space_account_uuid: str = None,
) -> Mock:
"""Helper to create mock User entity."""
user = Mock(spec=User)
user.user = email
user.password = password
user.account_type = account_type
user.space_account_uuid = space_account_uuid
return user
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestUserServiceIsInitialized:
"""Tests for is_initialized method."""
async def test_is_initialized_returns_true_when_users_exist(self):
"""Returns True when at least one user exists."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user()
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.is_initialized()
# Verify
assert result is True
async def test_is_initialized_returns_false_when_no_users(self):
"""Returns False when no users exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.is_initialized()
# Verify
assert result is False
async def test_is_initialized_returns_false_on_none_result(self):
"""Returns False when result is None."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = Mock()
mock_result.all = Mock(return_value=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.is_initialized()
# Verify
assert result is False
class TestUserServiceGetUserByEmail:
"""Tests for get_user_by_email method."""
async def test_get_user_by_email_found(self):
"""Returns user when found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(email='found@example.com')
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_user_by_email('found@example.com')
# Verify
assert result is not None
assert result.user == 'found@example.com'
async def test_get_user_by_email_not_found(self):
"""Returns None when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_user_by_email('notfound@example.com')
# Verify
assert result is None
async def test_get_user_by_email_empty_string(self):
"""Handles empty email string."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_user_by_email('')
# Verify
assert result is None
class TestUserServiceGetUserBySpaceAccountUuid:
"""Tests for get_user_by_space_account_uuid method."""
async def test_get_user_by_space_uuid_found(self):
"""Returns user when Space UUID found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(
email='space@example.com',
account_type='space',
space_account_uuid='space-uuid-123',
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_user_by_space_account_uuid('space-uuid-123')
# Verify
assert result is not None
assert result.space_account_uuid == 'space-uuid-123'
async def test_get_user_by_space_uuid_not_found(self):
"""Returns None when Space UUID not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_user_by_space_account_uuid('nonexistent-uuid')
# Verify
assert result is None
class TestUserServiceAuthenticate:
"""Tests for authenticate method."""
async def test_authenticate_user_not_found_raises_error(self):
"""Raises ValueError when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
service = UserService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='用户不存在'):
await service.authenticate('nonexistent@example.com', 'password')
async def test_authenticate_space_user_without_password_raises_error(self):
"""Raises ValueError for Space user without local password."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Space user has empty password
mock_user = _create_mock_user(
email='space@example.com',
password='', # Empty password for Space user
account_type='space',
)
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute & Verify
with pytest.raises(ValueError, match='请使用 Space 账户登录'):
await service.authenticate('space@example.com', 'password')
class TestUserServiceGenerateJwtToken:
"""Tests for generate_jwt_token method."""
async def test_generate_jwt_token_returns_valid_token(self):
"""Generates valid JWT token."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
service = UserService(ap)
# Execute
token = await service.generate_jwt_token('test@example.com')
# Verify - JWT format (base64 encoded parts)
assert token is not None
assert len(token) > 0
parts = token.split('.')
assert len(parts) == 3 # JWT has 3 parts
async def test_generate_jwt_token_custom_expire(self):
"""Generates token with custom expiry."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 7200}}}
service = UserService(ap)
# Execute
token = await service.generate_jwt_token('test@example.com')
# Verify
assert token is not None
class TestUserServiceVerifyJwtToken:
"""Tests for verify_jwt_token method."""
async def test_verify_jwt_token_valid(self):
"""Verifies valid JWT token and returns user email."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
service = UserService(ap)
# First generate a valid token
token = await service.generate_jwt_token('verify@example.com')
# Execute
user_email = await service.verify_jwt_token(token)
# Verify
assert user_email == 'verify@example.com'
async def test_verify_jwt_token_invalid_raises_error(self):
"""Raises error for invalid JWT token."""
# Setup
ap = SimpleNamespace()
ap.instance_config = SimpleNamespace()
ap.instance_config.data = {'system': {'jwt': {'secret': 'test_secret', 'expire': 3600}}}
service = UserService(ap)
# Execute & Verify - invalid token should raise JWT error
with pytest.raises(Exception): # jwt.DecodeError or similar
await service.verify_jwt_token('invalid.token.here')
class TestUserServiceResetPassword:
"""Tests for reset_password method."""
async def test_reset_password_updates_password(self):
"""Updates user password."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = UserService(ap)
# Execute
await service.reset_password('test@example.com', 'new_password')
# Verify - execute_async was called with update
ap.persistence_mgr.execute_async.assert_called_once()
class TestUserServiceChangePassword:
"""Tests for change_password method."""
async def test_change_password_user_not_found_raises_error(self):
"""Raises ValueError when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
service = UserService(ap)
# Mock get_user_by_email to return None
service.get_user_by_email = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='User not found'):
await service.change_password('nonexistent@example.com', 'current', 'new')
async def test_change_password_no_local_password_raises_error(self):
"""Raises ValueError when user has no local password set."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
service = UserService(ap)
# Mock user without password
mock_user = _create_mock_user(email='nopass@example.com', password=None)
service.get_user_by_email = AsyncMock(return_value=mock_user)
# Execute & Verify
with pytest.raises(ValueError, match='No local password set'):
await service.change_password('nopass@example.com', 'current', 'new')
class TestUserServiceGetFirstUser:
"""Tests for get_first_user method."""
async def test_get_first_user_found(self):
"""Returns first user when exists."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_user = _create_mock_user(email='first@example.com')
mock_result = _create_mock_result([mock_user])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_first_user()
# Verify
assert result is not None
assert result.user == 'first@example.com'
async def test_get_first_user_not_found(self):
"""Returns None when no users exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = UserService(ap)
# Execute
result = await service.get_first_user()
# Verify
assert result is None
class TestUserServiceSetPassword:
"""Tests for set_password method."""
async def test_set_password_user_not_found_raises_error(self):
"""Raises ValueError when user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
service = UserService(ap)
# Mock get_user_by_email to return None
service.get_user_by_email = AsyncMock(return_value=None)
# Execute & Verify
with pytest.raises(ValueError, match='User not found'):
await service.set_password('nonexistent@example.com', 'new_password')
async def test_set_password_with_existing_password_requires_current(self):
"""Requires current password when user has existing password."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
service = UserService(ap)
# Mock user with existing password
mock_user = _create_mock_user(email='haspass@example.com', password='hashed_old_password')
service.get_user_by_email = AsyncMock(return_value=mock_user)
# Execute & Verify - should raise when no current_password provided
with pytest.raises(ValueError, match='Current password is required'):
await service.set_password('haspass@example.com', 'new_password')
class TestUserServiceCreateOrUpdateSpaceUser:
"""Tests for create_or_update_space_user method."""
async def test_create_or_update_existing_space_user(self):
"""Updates existing Space user tokens."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.provider_service = SimpleNamespace()
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
service = UserService(ap)
# Mock existing Space user
existing_user = _create_mock_user(
email='space@example.com',
account_type='space',
space_account_uuid='existing-space-uuid',
)
service.get_user_by_space_account_uuid = AsyncMock(return_value=existing_user)
service.get_user_by_email = AsyncMock(return_value=None)
service.is_initialized = AsyncMock(return_value=True)
ap.persistence_mgr.execute_async = AsyncMock()
# Execute
updated_user = await service.create_or_update_space_user(
space_account_uuid='existing-space-uuid',
email='space@example.com',
access_token='new_access_token',
refresh_token='new_refresh_token',
api_key='new_api_key',
expires_in=3600,
)
# Verify - update was called and user returned
ap.persistence_mgr.execute_async.assert_called()
assert updated_user.space_account_uuid == 'existing-space-uuid'
async def test_create_or_update_new_space_user_first_init(self):
"""Creates new Space user on first initialization."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.provider_service = SimpleNamespace()
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
service = UserService(ap)
# Mock new user to be returned after creation
new_user = _create_mock_user(
email='newspace@example.com',
account_type='space',
space_account_uuid='new-space-uuid',
)
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
if call_count == 1: # First check for existing user
return None
return new_user # After insert, return the new user
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
service.get_user_by_email = AsyncMock(return_value=None)
service.is_initialized = AsyncMock(return_value=False) # Not initialized
ap.persistence_mgr.execute_async = AsyncMock()
# Execute
result = await service.create_or_update_space_user(
space_account_uuid='new-space-uuid',
email='newspace@example.com',
access_token='access_token',
refresh_token='refresh_token',
api_key='api_key',
expires_in=3600,
)
# Verify
assert result.space_account_uuid == 'new-space-uuid'
async def test_create_or_update_space_user_already_initialized_raises_error(self):
"""Raises AccountEmailMismatchError when system already initialized and user not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.provider_service = SimpleNamespace()
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
service = UserService(ap)
# Mock system already initialized, no matching users
service.get_user_by_space_account_uuid = AsyncMock(return_value=None)
service.get_user_by_email = AsyncMock(return_value=None)
service.is_initialized = AsyncMock(return_value=True) # Already initialized
# Execute & Verify
with pytest.raises(AccountEmailMismatchError):
await service.create_or_update_space_user(
space_account_uuid='unknown-space-uuid',
email='unknown@example.com',
access_token='token',
refresh_token='refresh',
api_key='key',
expires_in=3600,
)
async def test_create_or_update_space_user_no_expiry(self):
"""Creates Space user without token expiry."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.provider_service = SimpleNamespace()
ap.provider_service.update_space_model_provider_api_keys = AsyncMock()
service = UserService(ap)
new_user = _create_mock_user(
email='noexpiry@example.com',
account_type='space',
space_account_uuid='noexpiry-uuid',
)
# First call (line 138) returns None, second call (line 194) returns new_user
call_count = 0
async def mock_get_by_space_uuid(uuid):
nonlocal call_count
call_count += 1
if call_count == 1: # First check for existing user
return None
return new_user # After insert, return the new user
service.get_user_by_space_account_uuid = AsyncMock(side_effect=mock_get_by_space_uuid)
service.get_user_by_email = AsyncMock(return_value=None)
service.is_initialized = AsyncMock(return_value=False)
ap.persistence_mgr.execute_async = AsyncMock()
# Execute with expires_in=0 (no expiry)
result = await service.create_or_update_space_user(
space_account_uuid='noexpiry-uuid',
email='noexpiry@example.com',
access_token='token',
refresh_token='refresh',
api_key='key',
expires_in=0, # No expiry
)
# Verify
assert result is not None
assert result.space_account_uuid == 'noexpiry-uuid'
class TestUserServiceCreateUserLock:
"""Tests for create_user_lock attribute."""
def test_create_user_lock_initialized(self):
"""Verify create_user_lock is initialized as asyncio.Lock."""
# Setup
ap = SimpleNamespace()
service = UserService(ap)
# Verify lock exists
assert hasattr(service, '_create_user_lock')
assert service._create_user_lock is not None

View File

@@ -1,506 +0,0 @@
"""
Unit tests for WebhookService.
Tests webhook CRUD operations including:
- Webhook listing
- Webhook creation
- Webhook retrieval by ID
- Webhook updates
- Webhook deletion
- Enabled webhooks filtering
Source: src/langbot/pkg/api/http/service/webhook.py
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, Mock
from types import SimpleNamespace
from langbot.pkg.api.http.service.webhook import WebhookService
from langbot.pkg.entity.persistence.webhook import Webhook
pytestmark = pytest.mark.asyncio
def _create_mock_webhook(
webhook_id: int = 1,
name: str = 'Test Webhook',
url: str = 'http://example.com/webhook',
description: str = 'Test Description',
enabled: bool = True,
) -> Mock:
"""Helper to create mock Webhook entity."""
webhook = Mock(spec=Webhook)
webhook.id = webhook_id
webhook.name = name
webhook.url = url
webhook.description = description
webhook.enabled = enabled
return webhook
def _create_mock_result(items: list = None, first_item=None):
"""Create mock result object for persistence queries."""
result = Mock()
result.all = Mock(return_value=items or [])
result.first = Mock(return_value=first_item)
return result
class TestWebhookServiceGetWebhooks:
"""Tests for get_webhooks method."""
async def test_get_webhooks_empty_list(self):
"""Returns empty list when no webhooks exist."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'url': entity.url,
}
)
service = WebhookService(ap)
# Execute
result = await service.get_webhooks()
# Verify
assert result == []
async def test_get_webhooks_returns_serialized_list(self):
"""Returns serialized list of webhooks."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
webhook1 = _create_mock_webhook(webhook_id=1, name='Webhook 1')
webhook2 = _create_mock_webhook(webhook_id=2, name='Webhook 2')
mock_result = _create_mock_result([webhook1, webhook2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'url': entity.url,
'description': entity.description,
'enabled': entity.enabled,
}
)
service = WebhookService(ap)
# Execute
result = await service.get_webhooks()
# Verify
assert len(result) == 2
assert result[0]['name'] == 'Webhook 1'
assert result[1]['name'] == 'Webhook 2'
class TestWebhookServiceCreateWebhook:
"""Tests for create_webhook method."""
async def test_create_webhook_full_params(self):
"""Creates webhook with all parameters."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Mock insert result
insert_result = Mock()
# Mock select result for retrieving created webhook
created_webhook = _create_mock_webhook(
webhook_id=1,
name='New Webhook',
url='http://new.example.com/webhook',
description='New Description',
enabled=True,
)
select_result = _create_mock_result(first_item=created_webhook)
# execute_async returns different results
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return insert_result # Insert
return select_result # Select
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'New Webhook',
'url': 'http://new.example.com/webhook',
'description': 'New Description',
'enabled': True,
}
)
service = WebhookService(ap)
# Execute
result = await service.create_webhook(
name='New Webhook',
url='http://new.example.com/webhook',
description='New Description',
enabled=True,
)
# Verify
assert result['name'] == 'New Webhook'
assert result['url'] == 'http://new.example.com/webhook'
assert result['description'] == 'New Description'
assert result['enabled'] is True
async def test_create_webhook_defaults(self):
"""Creates webhook with default description and enabled."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_webhook = _create_mock_webhook(
webhook_id=1,
name='Minimal Webhook',
url='http://minimal.example.com',
description='', # Default
enabled=True, # Default
)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return Mock() # Insert
return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'Minimal Webhook',
'url': 'http://minimal.example.com',
'description': '',
'enabled': True,
}
)
service = WebhookService(ap)
# Execute - only name and url required
result = await service.create_webhook(name='Minimal Webhook', url='http://minimal.example.com')
# Verify defaults
assert result['description'] == ''
assert result['enabled'] is True
async def test_create_webhook_disabled(self):
"""Creates webhook with enabled=False."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
created_webhook = _create_mock_webhook(webhook_id=1, enabled=False)
call_count = 0
async def mock_execute(query):
nonlocal call_count
call_count += 1
if call_count == 1:
return Mock()
return _create_mock_result(first_item=created_webhook)
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={'id': 1, 'enabled': False}
)
service = WebhookService(ap)
# Execute
result = await service.create_webhook(name='Disabled', url='http://disabled.com', enabled=False)
# Verify
assert result['enabled'] is False
class TestWebhookServiceGetWebhook:
"""Tests for get_webhook method."""
async def test_get_webhook_by_id_found(self):
"""Returns webhook when found by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
webhook = _create_mock_webhook(webhook_id=1, name='Found Webhook')
mock_result = _create_mock_result(first_item=webhook)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'id': 1,
'name': 'Found Webhook',
'url': 'http://example.com/webhook',
}
)
service = WebhookService(ap)
# Execute
result = await service.get_webhook(1)
# Verify
assert result is not None
assert result['id'] == 1
assert result['name'] == 'Found Webhook'
async def test_get_webhook_by_id_not_found(self):
"""Returns None when webhook not found."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = WebhookService(ap)
# Execute
result = await service.get_webhook(999)
# Verify
assert result is None
async def test_get_webhook_by_id_zero(self):
"""Handles ID=0 (edge case) correctly."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result(first_item=None)
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = WebhookService(ap)
# Execute
result = await service.get_webhook(0)
# Verify - should return None (no webhook with ID 0)
assert result is None
class TestWebhookServiceUpdateWebhook:
"""Tests for update_webhook method."""
async def test_update_webhook_name_only(self):
"""Updates only the name field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.update_webhook(1, name='Updated Name')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_webhook_url_only(self):
"""Updates only the url field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.update_webhook(1, url='http://updated.example.com')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_webhook_description_only(self):
"""Updates only the description field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.update_webhook(1, description='Updated description')
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_webhook_enabled_only(self):
"""Updates only the enabled field."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.update_webhook(1, enabled=False)
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_webhook_all_fields(self):
"""Updates all fields at once."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.update_webhook(
1,
name='All Updated',
url='http://all.updated.com',
description='All updated description',
enabled=False,
)
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_update_webhook_no_fields(self):
"""Does nothing when no fields provided."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute - no update parameters
await service.update_webhook(1)
# Verify - no execute call since no update_data
ap.persistence_mgr.execute_async.assert_not_called()
class TestWebhookServiceDeleteWebhook:
"""Tests for delete_webhook method."""
async def test_delete_webhook_by_id(self):
"""Deletes webhook by ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute
await service.delete_webhook(1)
# Verify
ap.persistence_mgr.execute_async.assert_called_once()
async def test_delete_webhook_nonexistent_id(self):
"""Delete operation completes even for nonexistent ID."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.persistence_mgr.execute_async = AsyncMock()
service = WebhookService(ap)
# Execute - should not raise
await service.delete_webhook(999)
# Verify - still called
ap.persistence_mgr.execute_async.assert_called_once()
class TestWebhookServiceGetEnabledWebhooks:
"""Tests for get_enabled_webhooks method."""
async def test_get_enabled_webhooks_empty(self):
"""Returns empty list when no enabled webhooks."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={})
service = WebhookService(ap)
# Execute
result = await service.get_enabled_webhooks()
# Verify
assert result == []
async def test_get_enabled_webhooks_filters_enabled(self):
"""Returns only enabled webhooks."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# All returned webhooks should be enabled (SQL filter)
webhook1 = _create_mock_webhook(webhook_id=1, name='Enabled 1', enabled=True)
webhook2 = _create_mock_webhook(webhook_id=2, name='Enabled 2', enabled=True)
mock_result = _create_mock_result([webhook1, webhook2])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(
side_effect=lambda model_cls, entity: {
'id': entity.id,
'name': entity.name,
'enabled': entity.enabled,
}
)
service = WebhookService(ap)
# Execute
result = await service.get_enabled_webhooks()
# Verify
assert len(result) == 2
assert all(w['enabled'] for w in result)
async def test_get_enabled_webhooks_filters_disabled(self):
"""Does not return disabled webhooks."""
# Setup
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
# Empty result because query filters on enabled=True
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
ap.persistence_mgr.serialize_model = Mock(return_value={})
service = WebhookService(ap)
# Execute
result = await service.get_enabled_webhooks()
# Verify - should be empty (SQL would filter disabled)
assert result == []

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
from langbot.pkg.api.http.service.apikey import ApiKeyService
@pytest.mark.asyncio
@pytest.mark.parametrize('api_key', [None, 123, b'lbk_bytes', '', 'plain_key', ' LBK_bad', 'sk-lbk_fake'])
async def test_verify_api_key_rejects_non_lbk_keys_without_db_query(api_key):
persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
result = await service.verify_api_key(api_key)
assert result is False
persistence_mgr.execute_async.assert_not_awaited()
@pytest.mark.asyncio
@pytest.mark.parametrize(
('db_row', 'expected'),
[
(object(), True),
(None, False),
],
)
async def test_verify_api_key_keeps_db_validation_for_lbk_keys(db_row, expected):
query_result = Mock()
query_result.first.return_value = db_row
persistence_mgr = SimpleNamespace(execute_async=AsyncMock(return_value=query_result))
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
result = await service.verify_api_key('lbk_valid_format')
assert result is expected
persistence_mgr.execute_async.assert_awaited_once()

View File

@@ -1 +0,0 @@
# Unit tests for command module

Some files were not shown because too many files have changed in this diff Show More